Merge branch 'main' into feat/hitl-frontend

This commit is contained in:
twwu 2025-12-25 13:43:27 +08:00
commit 8b9846f52b
3693 changed files with 107816 additions and 91946 deletions

View File

@ -1,13 +1,13 @@
---
name: frontend-testing
description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests.
---
# Dify Frontend Testing Skill
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
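As a quick orientation (an illustrative sketch, not part of `web/testing/testing.md`), the common `jest.*` calls used in older tests map directly onto `vi.*`:

```typescript
import { expect, it, vi } from 'vitest'

// Module mocks are hoisted to the top of the file, exactly like jest.mock()
vi.mock('@/service/api')

it('maps jest.* helpers onto vi.*', () => {
  const onSave = vi.fn()              // jest.fn()
  onSave('draft')
  expect(onSave).toHaveBeenCalledWith('draft')

  vi.useFakeTimers()                  // jest.useFakeTimers()
  vi.advanceTimersByTime(300)         // jest.advanceTimersByTime(300)
  vi.useRealTimers()                  // jest.useRealTimers()
})
```

If Vitest globals are enabled in `web/vitest.config.ts`, the explicit `import { expect, it, vi } from 'vitest'` line may be unnecessary.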
## When to Apply This Skill
@ -15,7 +15,7 @@ Apply this skill when the user:
- Asks to **write tests** for a component, hook, or utility
- Asks to **review existing tests** for completeness
- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files**
- Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files**
- Requests **test coverage** improvement
- Uses `pnpm analyze-component` output as context
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
@ -33,9 +33,9 @@ Apply this skill when the user:
| Tool | Version | Purpose |
|------|---------|---------|
| Jest | 29.7 | Test runner |
| Vitest | 4.0.16 | Test runner |
| React Testing Library | 16.0 | Component testing |
| happy-dom | - | Test environment |
| jsdom | - | Test environment |
| nock | 14.0 | HTTP mocking |
| TypeScript | 5.x | Type safety |
@ -46,13 +46,13 @@ Apply this skill when the user:
pnpm test
# Watch mode
pnpm test -- --watch
pnpm test:watch
# Run specific file
pnpm test -- path/to/file.spec.tsx
pnpm test path/to/file.spec.tsx
# Generate coverage report
pnpm test -- --coverage
pnpm test:coverage
# Analyze component complexity
pnpm analyze-component <path>
@ -77,9 +77,9 @@ import Component from './index'
// import { ChildComponent } from './child-component'
// ✅ Mock external dependencies only
jest.mock('@/service/api')
jest.mock('next/navigation', () => ({
useRouter: () => ({ push: jest.fn() }),
vi.mock('@/service/api')
vi.mock('next/navigation', () => ({
useRouter: () => ({ push: vi.fn() }),
usePathname: () => '/test',
}))
@ -88,7 +88,7 @@ let mockSharedState = false
describe('ComponentName', () => {
beforeEach(() => {
jest.clearAllMocks() // ✅ Reset mocks BEFORE each test
vi.clearAllMocks() // ✅ Reset mocks BEFORE each test
mockSharedState = false // ✅ Reset shared state
})
@ -117,7 +117,7 @@ describe('ComponentName', () => {
// User Interactions
describe('User Interactions', () => {
it('should handle click events', () => {
const handleClick = jest.fn()
const handleClick = vi.fn()
render(<Component onClick={handleClick} />)
fireEvent.click(screen.getByRole('button'))
@ -155,7 +155,7 @@ describe('ComponentName', () => {
For each file:
┌────────────────────────────────────────┐
│ 1. Write test │
│ 2. Run: pnpm test -- <file>.spec.tsx │
│ 2. Run: pnpm test <file>.spec.tsx │
│ 3. PASS? → Mark complete, next file │
│ FAIL? → Fix first, then continue │
└────────────────────────────────────────┘
@ -316,7 +316,7 @@ For more detailed information, refer to:
### Project Configuration
- `web/jest.config.ts` - Jest configuration
- `web/jest.setup.ts` - Test environment setup
- `web/vitest.config.ts` - Vitest configuration
- `web/vitest.setup.ts` - Test environment setup
- `web/testing/analyze-component.js` - Component analysis tool
- `web/__mocks__/react-i18next.ts` - Shared i18n mock (auto-loaded by Jest, no explicit mock needed; override locally only for custom translations)
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
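For instance, a local mock for a module such as `ky` might look like the sketch below (the factory shape is an assumption; match it to whatever the code under test actually imports):

```typescript
// Hypothetical local mock for the `ky` HTTP client - adjust the factory
// to the methods the module under test really calls.
vi.mock('ky', () => ({
  default: {
    get: vi.fn(() => ({ json: () => Promise.resolve({ data: [] }) })),
    post: vi.fn(() => ({ json: () => Promise.resolve({ result: 'success' }) })),
  },
}))
```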

View File

@ -23,14 +23,14 @@ import userEvent from '@testing-library/user-event'
// ============================================================================
// Mocks
// ============================================================================
// WHY: Mocks must be hoisted to top of file (Jest requirement).
// WHY: Mocks must be hoisted to top of file (Vitest requirement).
// They run BEFORE imports, so keep them before component imports.
// i18n (automatically mocked)
// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
// No explicit mock needed - it returns translation keys as-is
// Override only if custom translations are required:
// jest.mock('react-i18next', () => ({
// vi.mock('react-i18next', () => ({
// useTranslation: () => ({
// t: (key: string) => {
// const customTranslations: Record<string, string> = {
@ -43,17 +43,17 @@ import userEvent from '@testing-library/user-event'
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = jest.fn()
// jest.mock('next/navigation', () => ({
// const mockPush = vi.fn()
// vi.mock('next/navigation', () => ({
// useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path',
// }))
// API services (if component fetches data)
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
// jest.mock('@/service/api')
// vi.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>
// const mockedApi = vi.mocked(api)
// Shared mock state (for portal/dropdown components)
// WHY: Portal components like PortalToFollowElem need shared state between
@ -98,7 +98,7 @@ describe('ComponentName', () => {
// - Prevents mock call history from leaking between tests
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
// mockOpenState = false
})
@ -155,7 +155,7 @@ describe('ComponentName', () => {
// - userEvent simulates real user behavior (focus, hover, then click)
// - fireEvent is lower-level, doesn't trigger all browser events
// const user = userEvent.setup()
// const handleClick = jest.fn()
// const handleClick = vi.fn()
// render(<ComponentName onClick={handleClick} />)
//
// await user.click(screen.getByRole('button'))
@ -165,7 +165,7 @@ describe('ComponentName', () => {
it('should call onChange when value changes', async () => {
// const user = userEvent.setup()
// const handleChange = jest.fn()
// const handleChange = vi.fn()
// render(<ComponentName onChange={handleChange} />)
//
// await user.type(screen.getByRole('textbox'), 'new value')
@ -198,7 +198,7 @@ describe('ComponentName', () => {
})
// --------------------------------------------------------------------------
// Async Operations (if component fetches data - useSWR, useQuery, fetch)
// Async Operations (if component fetches data - useQuery, fetch)
// --------------------------------------------------------------------------
// WHY: Async operations have 3 states users experience: loading, success, error
describe('Async Operations', () => {
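// Illustrative sketch with a hypothetical <AsyncList fetchData={...} /> -
// names are assumptions, adapt them to the real component under test:
//
// it('should show a loading state, then the data', async () => {
//   const fetchData = vi.fn().mockResolvedValue({ name: 'Doc 1' })
//   render(<AsyncList fetchData={fetchData} />)
//
//   expect(screen.getByText(/loading/i)).toBeInTheDocument()
//   expect(await screen.findByText('Doc 1')).toBeInTheDocument()
// })
//
// it('should surface fetch errors', async () => {
//   const fetchData = vi.fn().mockRejectedValue(new Error('Network error'))
//   render(<AsyncList fetchData={fetchData} />)
//
//   expect(await screen.findByText(/error/i)).toBeInTheDocument()
// })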

View File

@ -15,9 +15,9 @@ import { renderHook, act, waitFor } from '@testing-library/react'
// ============================================================================
// API services (if hook fetches data)
// jest.mock('@/service/api')
// vi.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = api as jest.Mocked<typeof api>
// const mockedApi = vi.mocked(api)
// ============================================================================
// Test Helpers
@ -38,7 +38,7 @@ import { renderHook, act, waitFor } from '@testing-library/react'
describe('useHookName', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
})
// --------------------------------------------------------------------------
@ -145,7 +145,7 @@ describe('useHookName', () => {
// --------------------------------------------------------------------------
describe('Side Effects', () => {
it('should call callback when value changes', () => {
// const callback = jest.fn()
// const callback = vi.fn()
// const { result } = renderHook(() => useHookName({ onChange: callback }))
//
// act(() => {
@ -156,9 +156,9 @@ describe('useHookName', () => {
})
it('should cleanup on unmount', () => {
// const cleanup = jest.fn()
// jest.spyOn(window, 'addEventListener')
// jest.spyOn(window, 'removeEventListener')
// const cleanup = vi.fn()
// vi.spyOn(window, 'addEventListener')
// vi.spyOn(window, 'removeEventListener')
//
// const { unmount } = renderHook(() => useHookName())
//

View File

@ -49,7 +49,7 @@ import userEvent from '@testing-library/user-event'
it('should submit form', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn()
const onSubmit = vi.fn()
render(<Form onSubmit={onSubmit} />)
@ -77,15 +77,15 @@ it('should submit form', async () => {
```typescript
describe('Debounced Search', () => {
beforeEach(() => {
jest.useFakeTimers()
vi.useFakeTimers()
})
afterEach(() => {
jest.useRealTimers()
vi.useRealTimers()
})
it('should debounce search input', async () => {
const onSearch = jest.fn()
const onSearch = vi.fn()
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
// Type in the input
@ -95,7 +95,7 @@ describe('Debounced Search', () => {
expect(onSearch).not.toHaveBeenCalled()
// Advance timers
jest.advanceTimersByTime(300)
vi.advanceTimersByTime(300)
// Now search is called
expect(onSearch).toHaveBeenCalledWith('query')
@ -107,8 +107,8 @@ describe('Debounced Search', () => {
```typescript
it('should retry on failure', async () => {
jest.useFakeTimers()
const fetchData = jest.fn()
vi.useFakeTimers()
const fetchData = vi.fn()
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ data: 'success' })
@ -120,7 +120,7 @@ it('should retry on failure', async () => {
})
// Advance timer for retry
jest.advanceTimersByTime(1000)
vi.advanceTimersByTime(1000)
// Second call succeeds
await waitFor(() => {
@ -128,7 +128,7 @@ it('should retry on failure', async () => {
expect(screen.getByText('success')).toBeInTheDocument()
})
jest.useRealTimers()
vi.useRealTimers()
})
```
@ -136,19 +136,19 @@ it('should retry on failure', async () => {
```typescript
// Run all pending timers
jest.runAllTimers()
vi.runAllTimers()
// Run only pending timers (not new ones created during execution)
jest.runOnlyPendingTimers()
vi.runOnlyPendingTimers()
// Advance by specific time
jest.advanceTimersByTime(1000)
vi.advanceTimersByTime(1000)
// Get current fake time
jest.now()
Date.now()
// Clear all timers
jest.clearAllTimers()
vi.clearAllTimers()
```
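One caveat worth noting (a hedged addition, not from `testing.md`): when fake timers are combined with `userEvent`, pass an `advanceTimers` delegate so the simulated keystroke delays do not stall the test:

```typescript
it('should work with fake timers and userEvent together', async () => {
  vi.useFakeTimers()
  const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime })

  render(<SearchInput onSearch={vi.fn()} debounceMs={300} />)
  await user.type(screen.getByRole('textbox'), 'query')
  vi.advanceTimersByTime(300) // flush the debounce

  vi.useRealTimers()
})
```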
## API Testing Patterns
@ -158,7 +158,7 @@ jest.clearAllTimers()
```typescript
describe('DataFetcher', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
})
it('should show loading state', () => {
@ -241,7 +241,7 @@ it('should submit form and show success', async () => {
```typescript
it('should fetch data on mount', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
render(<ComponentWithEffect fetchData={fetchData} />)
@ -255,7 +255,7 @@ it('should fetch data on mount', async () => {
```typescript
it('should refetch when id changes', async () => {
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
@ -276,8 +276,8 @@ it('should refetch when id changes', async () => {
```typescript
it('should cleanup subscription on unmount', () => {
const subscribe = jest.fn()
const unsubscribe = jest.fn()
const subscribe = vi.fn()
const unsubscribe = vi.fn()
subscribe.mockReturnValue(unsubscribe)
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
@ -332,14 +332,14 @@ expect(description).toBeInTheDocument()
```typescript
// Bad - fake timers don't work well with real Promises
jest.useFakeTimers()
vi.useFakeTimers()
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
}) // May timeout!
// Good - use runAllTimers or advanceTimersByTime
jest.useFakeTimers()
vi.useFakeTimers()
render(<Component />)
jest.runAllTimers()
vi.runAllTimers()
expect(screen.getByText('Data')).toBeInTheDocument()
```

View File

@ -74,9 +74,9 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
### Mocks
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs
@ -114,15 +114,15 @@ For the current file being tested:
**Run these checks after EACH test file, not just at the end:**
- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file**
- [ ] Run `pnpm test path/to/file.spec.tsx` - **MUST PASS before next file**
- [ ] Fix any failures immediately
- [ ] Mark file as complete in todo list
- [ ] Only then proceed to next file
### After All Files Complete
- [ ] Run full directory test: `pnpm test -- path/to/directory/`
- [ ] Check coverage report: `pnpm test -- --coverage`
- [ ] Run full directory test: `pnpm test path/to/directory/`
- [ ] Check coverage report: `pnpm test:coverage`
- [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check:tsgo`
@ -132,10 +132,10 @@ For the current file being tested:
```typescript
// ❌ Mock doesn't match actual behavior
jest.mock('./Component', () => () => <div>Mocked</div>)
vi.mock('./Component', () => () => <div>Mocked</div>)
// ✅ Mock matches actual conditional logic
jest.mock('./Component', () => ({ isOpen }: any) =>
vi.mock('./Component', () => ({ isOpen }: any) =>
isOpen ? <div>Content</div> : null
)
```
@ -145,7 +145,7 @@ jest.mock('./Component', () => ({ isOpen }: any) =>
```typescript
// ❌ Shared state not reset
let mockState = false
jest.mock('./useHook', () => () => mockState)
vi.mock('./useHook', () => () => mockState)
// ✅ Reset in beforeEach
beforeEach(() => {
@ -186,16 +186,16 @@ Always test these scenarios:
```bash
# Run specific test
pnpm test -- path/to/file.spec.tsx
pnpm test path/to/file.spec.tsx
# Run with coverage
pnpm test -- --coverage path/to/file.spec.tsx
pnpm test:coverage path/to/file.spec.tsx
# Watch mode
pnpm test -- --watch path/to/file.spec.tsx
pnpm test:watch path/to/file.spec.tsx
# Update snapshots (use sparingly)
pnpm test -- -u path/to/file.spec.tsx
pnpm test -u path/to/file.spec.tsx
# Analyze component
pnpm analyze-component path/to/component.tsx

View File

@ -126,7 +126,7 @@ describe('Counter', () => {
describe('ControlledInput', () => {
it('should call onChange with new value', async () => {
const user = userEvent.setup()
const handleChange = jest.fn()
const handleChange = vi.fn()
render(<ControlledInput value="" onChange={handleChange} />)
@ -136,7 +136,7 @@ describe('ControlledInput', () => {
})
it('should display controlled value', () => {
render(<ControlledInput value="controlled" onChange={jest.fn()} />)
render(<ControlledInput value="controlled" onChange={vi.fn()} />)
expect(screen.getByRole('textbox')).toHaveValue('controlled')
})
@ -195,7 +195,7 @@ describe('ItemList', () => {
it('should handle item selection', async () => {
const user = userEvent.setup()
const onSelect = jest.fn()
const onSelect = vi.fn()
render(<ItemList items={items} onSelect={onSelect} />)
@ -217,20 +217,20 @@ describe('ItemList', () => {
```typescript
describe('Modal', () => {
it('should not render when closed', () => {
render(<Modal isOpen={false} onClose={jest.fn()} />)
render(<Modal isOpen={false} onClose={vi.fn()} />)
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
})
it('should render when open', () => {
render(<Modal isOpen={true} onClose={jest.fn()} />)
render(<Modal isOpen={true} onClose={vi.fn()} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should call onClose when clicking overlay', async () => {
const user = userEvent.setup()
const handleClose = jest.fn()
const handleClose = vi.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
@ -241,7 +241,7 @@ describe('Modal', () => {
it('should call onClose when pressing Escape', async () => {
const user = userEvent.setup()
const handleClose = jest.fn()
const handleClose = vi.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
@ -254,7 +254,7 @@ describe('Modal', () => {
const user = userEvent.setup()
render(
<Modal isOpen={true} onClose={jest.fn()}>
<Modal isOpen={true} onClose={vi.fn()}>
<button>First</button>
<button>Second</button>
</Modal>
@ -279,7 +279,7 @@ describe('Modal', () => {
describe('LoginForm', () => {
it('should submit valid form', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn()
const onSubmit = vi.fn()
render(<LoginForm onSubmit={onSubmit} />)
@ -296,7 +296,7 @@ describe('LoginForm', () => {
it('should show validation errors', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />)
render(<LoginForm onSubmit={vi.fn()} />)
// Submit empty form
await user.click(screen.getByRole('button', { name: /sign in/i }))
@ -308,7 +308,7 @@ describe('LoginForm', () => {
it('should validate email format', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={jest.fn()} />)
render(<LoginForm onSubmit={vi.fn()} />)
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
await user.click(screen.getByRole('button', { name: /sign in/i }))
@ -318,7 +318,7 @@ describe('LoginForm', () => {
it('should disable submit button while submitting', async () => {
const user = userEvent.setup()
const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
render(<LoginForm onSubmit={onSubmit} />)
@ -407,7 +407,7 @@ it('test 1', () => {
// Good - cleanup is automatic with RTL, but reset mocks
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
})
```

View File

@ -23,7 +23,7 @@ import NodeConfigPanel from './node-config-panel'
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
// Mock workflow context
jest.mock('@/app/components/workflow/hooks', () => ({
vi.mock('@/app/components/workflow/hooks', () => ({
useWorkflowStore: () => mockWorkflowStore,
useNodesInteractions: () => mockNodesInteractions,
}))
@ -31,21 +31,21 @@ jest.mock('@/app/components/workflow/hooks', () => ({
let mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: jest.fn(),
updateNode: vi.fn(),
}
let mockNodesInteractions = {
handleNodeSelect: jest.fn(),
handleNodeDelete: jest.fn(),
handleNodeSelect: vi.fn(),
handleNodeDelete: vi.fn(),
}
describe('NodeConfigPanel', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: jest.fn(),
updateNode: vi.fn(),
}
})
@ -161,23 +161,23 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DocumentUploader from './document-uploader'
jest.mock('@/service/datasets', () => ({
uploadDocument: jest.fn(),
parseDocument: jest.fn(),
vi.mock('@/service/datasets', () => ({
uploadDocument: vi.fn(),
parseDocument: vi.fn(),
}))
import * as datasetService from '@/service/datasets'
const mockedService = datasetService as jest.Mocked<typeof datasetService>
const mockedService = vi.mocked(datasetService)
describe('DocumentUploader', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
})
describe('File Upload', () => {
it('should accept valid file types', async () => {
const user = userEvent.setup()
const onUpload = jest.fn()
const onUpload = vi.fn()
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
render(<DocumentUploader onUpload={onUpload} />)
@ -326,14 +326,14 @@ describe('DocumentList', () => {
describe('Search & Filtering', () => {
it('should filter by search query', async () => {
const user = userEvent.setup()
jest.useFakeTimers()
vi.useFakeTimers()
render(<DocumentList datasetId="ds-1" />)
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
// Debounce
jest.advanceTimersByTime(300)
vi.advanceTimersByTime(300)
await waitFor(() => {
expect(mockedService.getDocuments).toHaveBeenCalledWith(
@ -342,7 +342,7 @@ describe('DocumentList', () => {
)
})
jest.useRealTimers()
vi.useRealTimers()
})
})
})
@ -367,13 +367,13 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import AppConfigForm from './app-config-form'
jest.mock('@/service/apps', () => ({
updateAppConfig: jest.fn(),
getAppConfig: jest.fn(),
vi.mock('@/service/apps', () => ({
updateAppConfig: vi.fn(),
getAppConfig: vi.fn(),
}))
import * as appService from '@/service/apps'
const mockedService = appService as jest.Mocked<typeof appService>
const mockedService = vi.mocked(appService)
describe('AppConfigForm', () => {
const defaultConfig = {
@ -384,7 +384,7 @@ describe('AppConfigForm', () => {
}
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
})

View File

@ -19,8 +19,8 @@
```typescript
// ❌ WRONG: Don't mock base components
jest.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
jest.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
// ✅ CORRECT: Import and use real base components
import Loading from '@/app/components/base/loading'
@ -41,20 +41,23 @@ Only mock these categories:
| Location | Purpose |
|----------|---------|
| `web/__mocks__/` | Reusable mocks shared across multiple test files |
| Test file | Test-specific mocks, inline with `jest.mock()` |
| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
| Test file | Test-specific mocks, inline with `vi.mock()` |
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
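As a sketch of the difference (the real `web/vitest.setup.ts` is authoritative), a global mock in the setup file applies to every test, while a `vi.mock` in a spec file applies only there:

```typescript
// web/vitest.setup.ts (illustrative excerpt - check the real file)
import { vi } from 'vitest'

vi.mock('react-i18next', () => ({
  useTranslation: () => ({
    t: (key: string) => key, // returns translation keys as-is
  }),
}))
```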
## Essential Mocks
### 1. i18n (Auto-loaded via Shared Mock)
### 1. i18n (Auto-loaded via Global Mock)
A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest.
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
**No explicit mock needed** for most tests - it returns translation keys as-is.
For tests requiring custom translations, override the mock:
```typescript
jest.mock('react-i18next', () => ({
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
const translations: Record<string, string> = {
@ -69,15 +72,15 @@ jest.mock('react-i18next', () => ({
### 2. Next.js Router
```typescript
const mockPush = jest.fn()
const mockReplace = jest.fn()
const mockPush = vi.fn()
const mockReplace = vi.fn()
jest.mock('next/navigation', () => ({
vi.mock('next/navigation', () => ({
useRouter: () => ({
push: mockPush,
replace: mockReplace,
back: jest.fn(),
prefetch: jest.fn(),
back: vi.fn(),
prefetch: vi.fn(),
}),
usePathname: () => '/current-path',
useSearchParams: () => new URLSearchParams('?key=value'),
@ -85,7 +88,7 @@ jest.mock('next/navigation', () => ({
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
})
it('should navigate on click', () => {
@ -102,7 +105,7 @@ describe('Component', () => {
// ⚠️ Important: Use shared state for components that depend on each other
let mockPortalOpenState = false
jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children, open, ...props }: any) => {
mockPortalOpenState = open || false // Update shared state
return <div data-testid="portal" data-open={open}>{children}</div>
@ -119,7 +122,7 @@ jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
mockPortalOpenState = false // ✅ Reset shared state
})
})
@ -130,13 +133,13 @@ describe('Component', () => {
```typescript
import * as api from '@/service/api'
jest.mock('@/service/api')
vi.mock('@/service/api')
const mockedApi = api as jest.Mocked<typeof api>
const mockedApi = vi.mocked(api)
describe('Component', () => {
beforeEach(() => {
jest.clearAllMocks()
vi.clearAllMocks()
// Setup default mock implementation
mockedApi.fetchData.mockResolvedValue({ data: [] })
@ -239,32 +242,9 @@ describe('Component with Context', () => {
})
```
### 7. SWR / React Query
### 7. React Query
```typescript
// SWR
jest.mock('swr', () => ({
__esModule: true,
default: jest.fn(),
}))
import useSWR from 'swr'
const mockedUseSWR = useSWR as jest.Mock
describe('Component with SWR', () => {
it('should show loading state', () => {
mockedUseSWR.mockReturnValue({
data: undefined,
error: undefined,
isLoading: true,
})
render(<Component />)
expect(screen.getByText(/loading/i)).toBeInTheDocument()
})
})
// React Query
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
const createTestQueryClient = () => new QueryClient({

View File

@ -35,7 +35,7 @@ When testing a **single component, hook, or utility**:
2. Run `pnpm analyze-component <path>` (if available)
3. Check complexity score and features detected
4. Write the test file
5. Run test: `pnpm test -- <file>.spec.tsx`
5. Run test: `pnpm test <file>.spec.tsx`
6. Fix any failures
7. Verify coverage meets goals (100% function, >95% branch)
```
@ -80,7 +80,7 @@ Process files in this recommended order:
```
┌─────────────────────────────────────────────┐
│ 1. Write test file │
│ 2. Run: pnpm test -- <file>.spec.tsx │
│ 2. Run: pnpm test <file>.spec.tsx │
│ 3. If FAIL → Fix immediately, re-run │
│ 4. If PASS → Mark complete in todo list │
│ 5. ONLY THEN proceed to next file │
@ -95,10 +95,10 @@ After all individual tests pass:
```bash
# Run all tests in the directory together
pnpm test -- path/to/directory/
pnpm test path/to/directory/
# Check coverage
pnpm test -- --coverage path/to/directory/
pnpm test:coverage path/to/directory/
```
## Component Complexity Guidelines
@ -201,9 +201,9 @@ Run pnpm test ← Multiple failures, hard to debug
```
# GOOD: Incremental with verification
Write component-a.spec.tsx
Run pnpm test -- component-a.spec.tsx ✅
Run pnpm test component-a.spec.tsx ✅
Write component-b.spec.tsx
Run pnpm test -- component-b.spec.tsx ✅
Run pnpm test component-b.spec.tsx ✅
...continue...
```

View File

@ -6,6 +6,9 @@
"context": "..",
"dockerfile": "Dockerfile"
},
"mounts": [
"source=dify-dev-tmp,target=/tmp,type=volume"
],
"features": {
"ghcr.io/devcontainers/features/node:1": {
"nodeGypDependencies": true,
@ -34,19 +37,13 @@
},
"postStartCommand": "./.devcontainer/post_start_command.sh",
"postCreateCommand": "./.devcontainer/post_create_command.sh"
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "python --version",
// Configure tool-specific properties.
// "customizations": {},
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}
}

View File

@ -1,6 +1,7 @@
#!/bin/bash
WORKSPACE_ROOT=$(pwd)
export COREPACK_ENABLE_DOWNLOAD_PROMPT=0
corepack enable
cd web && pnpm install
pipx install uv

309
.github/CODEOWNERS vendored
View File

@ -7,244 +7,243 @@
* @crazywoola @laipz8200 @Yeuoly
# CODEOWNERS file
.github/CODEOWNERS @laipz8200 @crazywoola
/.github/CODEOWNERS @laipz8200 @crazywoola
# Docs
docs/ @crazywoola
/docs/ @crazywoola
# Backend (default owner, more specific rules below will override)
api/ @QuantumGhost
/api/ @QuantumGhost
# Backend - MCP
api/core/mcp/ @Nov1c444
api/core/entities/mcp_provider.py @Nov1c444
api/services/tools/mcp_tools_manage_service.py @Nov1c444
api/controllers/mcp/ @Nov1c444
api/controllers/console/app/mcp_server.py @Nov1c444
api/tests/**/*mcp* @Nov1c444
/api/core/mcp/ @Nov1c444
/api/core/entities/mcp_provider.py @Nov1c444
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
/api/controllers/mcp/ @Nov1c444
/api/controllers/console/app/mcp_server.py @Nov1c444
/api/tests/**/*mcp* @Nov1c444
# Backend - Workflow - Engine (Core graph execution engine)
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
api/core/workflow/graph/ @laipz8200 @QuantumGhost
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
api/core/model_runtime/ @laipz8200 @QuantumGhost
/api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
/api/core/workflow/runtime/ @laipz8200 @QuantumGhost
/api/core/workflow/graph/ @laipz8200 @QuantumGhost
/api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
/api/core/workflow/node_events/ @laipz8200 @QuantumGhost
/api/core/model_runtime/ @laipz8200 @QuantumGhost
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
api/core/workflow/nodes/agent/ @Nov1c444
api/core/workflow/nodes/iteration/ @Nov1c444
api/core/workflow/nodes/loop/ @Nov1c444
api/core/workflow/nodes/llm/ @Nov1c444
/api/core/workflow/nodes/agent/ @Nov1c444
/api/core/workflow/nodes/iteration/ @Nov1c444
/api/core/workflow/nodes/loop/ @Nov1c444
/api/core/workflow/nodes/llm/ @Nov1c444
# Backend - RAG (Retrieval Augmented Generation)
api/core/rag/ @JohnJyong
api/services/rag_pipeline/ @JohnJyong
api/services/dataset_service.py @JohnJyong
api/services/knowledge_service.py @JohnJyong
api/services/external_knowledge_service.py @JohnJyong
api/services/hit_testing_service.py @JohnJyong
api/services/metadata_service.py @JohnJyong
api/services/vector_service.py @JohnJyong
api/services/entities/knowledge_entities/ @JohnJyong
api/services/entities/external_knowledge_entities/ @JohnJyong
api/controllers/console/datasets/ @JohnJyong
api/controllers/service_api/dataset/ @JohnJyong
api/models/dataset.py @JohnJyong
api/tasks/rag_pipeline/ @JohnJyong
api/tasks/add_document_to_index_task.py @JohnJyong
api/tasks/batch_clean_document_task.py @JohnJyong
api/tasks/clean_document_task.py @JohnJyong
api/tasks/clean_notion_document_task.py @JohnJyong
api/tasks/document_indexing_task.py @JohnJyong
api/tasks/document_indexing_sync_task.py @JohnJyong
api/tasks/document_indexing_update_task.py @JohnJyong
api/tasks/duplicate_document_indexing_task.py @JohnJyong
api/tasks/recover_document_indexing_task.py @JohnJyong
api/tasks/remove_document_from_index_task.py @JohnJyong
api/tasks/retry_document_indexing_task.py @JohnJyong
api/tasks/sync_website_document_indexing_task.py @JohnJyong
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
api/tasks/create_segment_to_index_task.py @JohnJyong
api/tasks/delete_segment_from_index_task.py @JohnJyong
api/tasks/disable_segment_from_index_task.py @JohnJyong
api/tasks/disable_segments_from_index_task.py @JohnJyong
api/tasks/enable_segment_to_index_task.py @JohnJyong
api/tasks/enable_segments_to_index_task.py @JohnJyong
api/tasks/clean_dataset_task.py @JohnJyong
api/tasks/deal_dataset_index_update_task.py @JohnJyong
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
/api/core/rag/ @JohnJyong
/api/services/rag_pipeline/ @JohnJyong
/api/services/dataset_service.py @JohnJyong
/api/services/knowledge_service.py @JohnJyong
/api/services/external_knowledge_service.py @JohnJyong
/api/services/hit_testing_service.py @JohnJyong
/api/services/metadata_service.py @JohnJyong
/api/services/vector_service.py @JohnJyong
/api/services/entities/knowledge_entities/ @JohnJyong
/api/services/entities/external_knowledge_entities/ @JohnJyong
/api/controllers/console/datasets/ @JohnJyong
/api/controllers/service_api/dataset/ @JohnJyong
/api/models/dataset.py @JohnJyong
/api/tasks/rag_pipeline/ @JohnJyong
/api/tasks/add_document_to_index_task.py @JohnJyong
/api/tasks/batch_clean_document_task.py @JohnJyong
/api/tasks/clean_document_task.py @JohnJyong
/api/tasks/clean_notion_document_task.py @JohnJyong
/api/tasks/document_indexing_task.py @JohnJyong
/api/tasks/document_indexing_sync_task.py @JohnJyong
/api/tasks/document_indexing_update_task.py @JohnJyong
/api/tasks/duplicate_document_indexing_task.py @JohnJyong
/api/tasks/recover_document_indexing_task.py @JohnJyong
/api/tasks/remove_document_from_index_task.py @JohnJyong
/api/tasks/retry_document_indexing_task.py @JohnJyong
/api/tasks/sync_website_document_indexing_task.py @JohnJyong
/api/tasks/batch_create_segment_to_index_task.py @JohnJyong
/api/tasks/create_segment_to_index_task.py @JohnJyong
/api/tasks/delete_segment_from_index_task.py @JohnJyong
/api/tasks/disable_segment_from_index_task.py @JohnJyong
/api/tasks/disable_segments_from_index_task.py @JohnJyong
/api/tasks/enable_segment_to_index_task.py @JohnJyong
/api/tasks/enable_segments_to_index_task.py @JohnJyong
/api/tasks/clean_dataset_task.py @JohnJyong
/api/tasks/deal_dataset_index_update_task.py @JohnJyong
/api/tasks/deal_dataset_vector_index_task.py @JohnJyong
# Backend - Plugins
api/core/plugin/ @Mairuis @Yeuoly @Stream29
api/services/plugin/ @Mairuis @Yeuoly @Stream29
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
/api/core/plugin/ @Mairuis @Yeuoly @Stream29
/api/services/plugin/ @Mairuis @Yeuoly @Stream29
/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
# Backend - Trigger/Schedule/Webhook
api/controllers/trigger/ @Mairuis @Yeuoly
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
api/core/trigger/ @Mairuis @Yeuoly
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
api/services/trigger/ @Mairuis @Yeuoly
api/models/trigger.py @Mairuis @Yeuoly
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/libs/schedule_utils.py @Mairuis @Yeuoly
api/services/workflow/scheduler.py @Mairuis @Yeuoly
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
/api/controllers/trigger/ @Mairuis @Yeuoly
/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
/api/core/trigger/ @Mairuis @Yeuoly
/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
/api/services/trigger/ @Mairuis @Yeuoly
/api/models/trigger.py @Mairuis @Yeuoly
/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
/api/libs/schedule_utils.py @Mairuis @Yeuoly
/api/services/workflow/scheduler.py @Mairuis @Yeuoly
/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
# Backend - Async Workflow
api/services/async_workflow_service.py @Mairuis @Yeuoly
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
/api/services/async_workflow_service.py @Mairuis @Yeuoly
/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
# Backend - Billing
api/services/billing_service.py @hj24 @zyssyz123
api/controllers/console/billing/ @hj24 @zyssyz123
/api/services/billing_service.py @hj24 @zyssyz123
/api/controllers/console/billing/ @hj24 @zyssyz123
# Backend - Enterprise
api/configs/enterprise/ @GarfieldDai @GareArc
api/services/enterprise/ @GarfieldDai @GareArc
api/services/feature_service.py @GarfieldDai @GareArc
api/controllers/console/feature.py @GarfieldDai @GareArc
api/controllers/web/feature.py @GarfieldDai @GareArc
/api/configs/enterprise/ @GarfieldDai @GareArc
/api/services/enterprise/ @GarfieldDai @GareArc
/api/services/feature_service.py @GarfieldDai @GareArc
/api/controllers/console/feature.py @GarfieldDai @GareArc
/api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
api/migrations/ @snakevash @laipz8200 @MRZHUH
/api/migrations/ @snakevash @laipz8200 @MRZHUH
# Backend - Vector DB Middleware
api/configs/middleware/vdb/* @JohnJyong
/api/configs/middleware/vdb/* @JohnJyong
# Frontend
web/ @iamjoel
/web/ @iamjoel
# Frontend - Web Tests
.github/workflows/web-tests.yml @iamjoel
/.github/workflows/web-tests.yml @iamjoel
# Frontend - App - Orchestration
web/app/components/workflow/ @iamjoel @zxhlyh
web/app/components/workflow-app/ @iamjoel @zxhlyh
web/app/components/app/configuration/ @iamjoel @zxhlyh
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
/web/app/components/workflow/ @iamjoel @zxhlyh
/web/app/components/workflow-app/ @iamjoel @zxhlyh
/web/app/components/app/configuration/ @iamjoel @zxhlyh
/web/app/components/app/app-publisher/ @iamjoel @zxhlyh
# Frontend - WebApp - Chat
web/app/components/base/chat/ @iamjoel @zxhlyh
/web/app/components/base/chat/ @iamjoel @zxhlyh
# Frontend - WebApp - Completion
web/app/components/share/text-generation/ @iamjoel @zxhlyh
/web/app/components/share/text-generation/ @iamjoel @zxhlyh
# Frontend - App - List and Creation
web/app/components/apps/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
/web/app/components/apps/ @JzoNgKVO @iamjoel
/web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
/web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
/web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
# Frontend - App - API Documentation
web/app/components/develop/ @JzoNgKVO @iamjoel
/web/app/components/develop/ @JzoNgKVO @iamjoel
# Frontend - App - Logs and Annotations
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
web/app/components/app/log/ @JzoNgKVO @iamjoel
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
/web/app/components/app/log/ @JzoNgKVO @iamjoel
/web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
/web/app/components/app/annotation/ @JzoNgKVO @iamjoel
# Frontend - App - Monitoring
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
web/app/components/app/overview/ @JzoNgKVO @iamjoel
/web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
/web/app/components/app/overview/ @JzoNgKVO @iamjoel
# Frontend - App - Settings
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
/web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
# Frontend - RAG - Hit Testing
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
/web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
# Frontend - RAG - List and Creation
web/app/components/datasets/list/ @iamjoel @WTW0313
web/app/components/datasets/create/ @iamjoel @WTW0313
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
/web/app/components/datasets/list/ @iamjoel @WTW0313
/web/app/components/datasets/create/ @iamjoel @WTW0313
/web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
/web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
web/app/components/rag-pipeline/ @iamjoel @WTW0313
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
/web/app/components/rag-pipeline/ @iamjoel @WTW0313
/web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
/web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
# Frontend - RAG - Documents List
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
/web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
/web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
# Frontend - RAG - Segments List
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
/web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
# Frontend - RAG - Settings
web/app/components/datasets/settings/ @iamjoel @WTW0313
/web/app/components/datasets/settings/ @iamjoel @WTW0313
# Frontend - Ecosystem - Plugins
web/app/components/plugins/ @iamjoel @zhsama
/web/app/components/plugins/ @iamjoel @zhsama
# Frontend - Ecosystem - Tools
web/app/components/tools/ @iamjoel @Yessenia-d
/web/app/components/tools/ @iamjoel @Yessenia-d
# Frontend - Ecosystem - MarketPlace
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
/web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
# Frontend - Login and Registration
web/app/signin/ @douxc @iamjoel
web/app/signup/ @douxc @iamjoel
web/app/reset-password/ @douxc @iamjoel
web/app/install/ @douxc @iamjoel
web/app/init/ @douxc @iamjoel
web/app/forgot-password/ @douxc @iamjoel
web/app/account/ @douxc @iamjoel
/web/app/signin/ @douxc @iamjoel
/web/app/signup/ @douxc @iamjoel
/web/app/reset-password/ @douxc @iamjoel
/web/app/install/ @douxc @iamjoel
/web/app/init/ @douxc @iamjoel
/web/app/forgot-password/ @douxc @iamjoel
/web/app/account/ @douxc @iamjoel
# Frontend - Service Authentication
web/service/base.ts @douxc @iamjoel
/web/service/base.ts @douxc @iamjoel
# Frontend - WebApp Authentication and Access Control
web/app/(shareLayout)/components/ @douxc @iamjoel
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
web/app/components/app/app-access-control/ @douxc @iamjoel
/web/app/(shareLayout)/components/ @douxc @iamjoel
/web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
/web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
/web/app/components/app/app-access-control/ @douxc @iamjoel
# Frontend - Explore Page
web/app/components/explore/ @CodingOnStar @iamjoel
/web/app/components/explore/ @CodingOnStar @iamjoel
# Frontend - Personal Settings
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
/web/app/components/header/account-setting/ @CodingOnStar @iamjoel
/web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
# Frontend - Analytics
web/app/components/base/ga/ @CodingOnStar @iamjoel
/web/app/components/base/ga/ @CodingOnStar @iamjoel
# Frontend - Base Components
web/app/components/base/ @iamjoel @zxhlyh
/web/app/components/base/ @iamjoel @zxhlyh
# Frontend - Utils and Hooks
web/utils/classnames.ts @iamjoel @zxhlyh
web/utils/time.ts @iamjoel @zxhlyh
web/utils/format.ts @iamjoel @zxhlyh
web/utils/clipboard.ts @iamjoel @zxhlyh
web/hooks/use-document-title.ts @iamjoel @zxhlyh
/web/utils/classnames.ts @iamjoel @zxhlyh
/web/utils/time.ts @iamjoel @zxhlyh
/web/utils/format.ts @iamjoel @zxhlyh
/web/utils/clipboard.ts @iamjoel @zxhlyh
/web/hooks/use-document-title.ts @iamjoel @zxhlyh
# Frontend - Billing and Education
web/app/components/billing/ @iamjoel @zxhlyh
web/app/education-apply/ @iamjoel @zxhlyh
/web/app/components/billing/ @iamjoel @zxhlyh
/web/app/education-apply/ @iamjoel @zxhlyh
# Frontend - Workspace
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
/web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
# Docker
docker/* @laipz8200
/docker/* @laipz8200

View File

@ -68,25 +68,4 @@ jobs:
run: |
uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v4
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: oxlint
working-directory: ./web
run: pnpm exec oxlint --config .oxlintrc.json --fix .
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

View File

@ -1,4 +1,4 @@
name: Check i18n Files and Create PR
name: Translate i18n Files Based on English
on:
push:
@ -67,25 +67,19 @@ jobs:
working-directory: ./web
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
- name: Generate i18n type definitions
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run gen:i18n-types
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: 'chore(i18n): update translations based on en-US changes'
title: 'chore(i18n): translate i18n files and update type definitions'
title: 'chore(i18n): translate i18n files based on en-US changes'
body: |
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
This PR was automatically created to update i18n translation files based on changes in en-US locale.
**Triggered by:** ${{ github.sha }}
**Changes included:**
- Updated translation files for all locales
- Regenerated TypeScript type definitions for type safety
branch: chore/automated-i18n-updates-${{ github.sha }}
delete-branch: true

View File

@ -35,27 +35,11 @@ jobs:
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Restore Jest cache
uses: actions/cache@v4
with:
path: web/.cache/jest
key: ${{ runner.os }}-jest-${{ hashFiles('web/pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-jest-
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Check i18n types synchronization
run: pnpm run check:i18n-types
- name: Run tests
run: |
pnpm exec jest \
--ci \
--maxWorkers=100% \
--coverage \
--passWithNoTests
run: pnpm test:coverage
- name: Coverage Summary
if: always()
@ -69,7 +53,7 @@ jobs:
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
exit 0
fi
@ -365,7 +349,7 @@ jobs:
.join(' | ')} |`;
console.log('');
console.log('<details><summary>Jest coverage table</summary>');
console.log('<details><summary>Vitest coverage table</summary>');
console.log('');
console.log(headerRow);
console.log(dividerRow);

12
.gitignore vendored
View File

@ -139,7 +139,6 @@ pyrightconfig.json
.idea/'
.DS_Store
web/.vscode/settings.json
# Intellij IDEA Files
.idea/*
@ -196,6 +195,7 @@ docker/nginx/ssl/*
!docker/nginx/ssl/.gitkeep
docker/middleware.env
docker/docker-compose.override.yaml
docker/env-backup/*
sdks/python-client/build
sdks/python-client/dist
@ -205,7 +205,6 @@ sdks/python-client/dify_client.egg-info
!.vscode/launch.json.template
!.vscode/README.md
api/.vscode
web/.vscode
# vscode Code History Extension
.history
@ -220,15 +219,6 @@ plugins.jsonl
# mise
mise.toml
# Next.js build output
.next/
# PWA generated files
web/public/sw.js
web/public/sw.js.map
web/public/workbox-*.js
web/public/workbox-*.js.map
web/public/fallback-*.js
# AI Assistant
.roo/

View File

@ -116,6 +116,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
@ -133,6 +134,7 @@ HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url
HUAWEI_OBS_PATH_STYLE=false
# Baidu OBS Storage Configuration
BAIDU_OBS_BUCKET_NAME=your-bucket-name
@ -690,7 +692,6 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000

View File

@ -41,3 +41,8 @@ class AliyunOSSStorageConfig(BaseSettings):
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
default=None,
)
ALIYUN_CLOUDBOX_ID: str | None = Field(
description="Cloudbox id for aliyun cloudbox service",
default=None,
)

View File

@ -26,3 +26,8 @@ class HuaweiCloudOBSStorageConfig(BaseSettings):
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
default=None,
)
HUAWEI_OBS_PATH_STYLE: bool = Field(
description="Flag to indicate whether to use path-style URLs for OBS requests",
default=False,
)

View File

@ -7,9 +7,9 @@ from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, extract_remote_ip, timezone
from libs.helper import EmailStr, timezone
from models import AccountStatus
from services.account_service import AccountService, RegisterService
from services.account_service import RegisterService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -93,7 +93,6 @@ class ActivateApi(Resource):
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.Raw(description="Login token data"),
},
),
)
@ -117,6 +116,4 @@ class ActivateApi(Resource):
account.initialized_at = naive_utc_now()
db.session.commit()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
return {"result": "success", "data": token_pair.model_dump()}
return {"result": "success"}

View File

@ -572,7 +572,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],

View File

@ -1,14 +1,32 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants import HIDDEN_VALUE
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required
class CodeBasedExtensionQuery(BaseModel):
module: str
class APIBasedExtensionPayload(BaseModel):
name: str = Field(description="Extension name")
api_endpoint: str = Field(description="API endpoint URL")
api_key: str = Field(description="API key for authentication")
register_schema_models(console_ns, APIBasedExtensionPayload)
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
@ -18,11 +36,7 @@ api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_m
class CodeBasedExtensionAPI(Resource):
@console_ns.doc("get_code_based_extension")
@console_ns.doc(description="Get code-based extension data by module name")
@console_ns.expect(
console_ns.parser().add_argument(
"module", type=str, required=True, location="args", help="Extension module name"
)
)
@console_ns.doc(params={"module": "Extension module name"})
@console_ns.response(
200,
"Success",
@ -35,10 +49,9 @@ class CodeBasedExtensionAPI(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
args = parser.parse_args()
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
@console_ns.route("/api-based-extension")
@ -56,30 +69,21 @@ class APIBasedExtensionAPI(Resource):
@console_ns.doc("create_api_based_extension")
@console_ns.doc(description="Create a new API-based extension")
@console_ns.expect(
console_ns.model(
"CreateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
"api_key": fields.String(required=True, description="API key for authentication"),
},
)
)
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self):
args = console_ns.payload
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
tenant_id=current_tenant_id,
name=args["name"],
api_endpoint=args["api_endpoint"],
api_key=args["api_key"],
name=payload.name,
api_endpoint=payload.api_endpoint,
api_key=payload.api_key,
)
return APIBasedExtensionService.save(extension_data)
@ -104,16 +108,7 @@ class APIBasedExtensionDetailAPI(Resource):
@console_ns.doc("update_api_based_extension")
@console_ns.doc(description="Update API-based extension")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.expect(
console_ns.model(
"UpdateAPIBasedExtensionRequest",
{
"name": fields.String(required=True, description="Extension name"),
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
"api_key": fields.String(required=True, description="API key for authentication"),
},
)
)
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
@setup_required
@login_required
@ -125,13 +120,13 @@ class APIBasedExtensionDetailAPI(Resource):
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
args = console_ns.payload
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
extension_data_from_db.name = args["name"]
extension_data_from_db.api_endpoint = args["api_endpoint"]
extension_data_from_db.name = payload.name
extension_data_from_db.api_endpoint = payload.api_endpoint
if args["api_key"] != HIDDEN_VALUE:
extension_data_from_db.api_key = args["api_key"]
if payload.api_key != HIDDEN_VALUE:
extension_data_from_db.api_key = payload.api_key
return APIBasedExtensionService.save(extension_data_from_db)
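The handlers above validate `console_ns.payload` against the Pydantic model instead of building a `reqparse` parser. A minimal sketch of what that validation does, reusing the same field definitions:

```python
from pydantic import BaseModel, Field, ValidationError

class APIBasedExtensionPayload(BaseModel):
    name: str = Field(description="Extension name")
    api_endpoint: str = Field(description="API endpoint URL")
    api_key: str = Field(description="API key for authentication")

payload = APIBasedExtensionPayload.model_validate(
    {"name": "moderation", "api_endpoint": "https://example.com/hook", "api_key": "sk-test"}
)
assert payload.api_endpoint == "https://example.com/hook"

try:
    APIBasedExtensionPayload.model_validate({})  # empty body -> validation error in the controller
except ValidationError as e:
    assert len(e.errors()) == 3  # all three fields are required
```

On update, the stored `api_key` is only replaced when the submitted value differs from `HIDDEN_VALUE`, so the masked placeholder echoed back by the UI never overwrites the stored secret.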

View File

@ -1,5 +1,6 @@
import io
from typing import Literal
from collections.abc import Mapping
from typing import Any, Literal
from flask import request, send_file
from flask_restx import Resource
@ -141,6 +142,15 @@ class ParserDynamicOptions(BaseModel):
provider_type: Literal["tool", "trigger"]
class ParserDynamicOptionsWithCredentials(BaseModel):
plugin_id: str
provider: str
action: str
parameter: str
credential_id: str
credentials: Mapping[str, Any]
class PluginPermissionSettingsPayload(BaseModel):
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
@ -183,6 +193,7 @@ reg(ParserGithubUpgrade)
reg(ParserUninstall)
reg(ParserPermissionChange)
reg(ParserDynamicOptions)
reg(ParserDynamicOptionsWithCredentials)
reg(ParserPreferencesChange)
reg(ParserExcludePlugin)
reg(ParserReadme)
@ -657,6 +668,37 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options-with-credentials")
class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
@console_ns.expect(console_ns.models[ParserDynamicOptionsWithCredentials.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
"""Fetch dynamic options using credentials directly (for edit mode)."""
current_user, tenant_id = current_account_with_tenant()
user_id = current_user.id
args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload)
try:
options = PluginParameterService.get_dynamic_select_options_with_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=args.plugin_id,
provider=args.provider,
action=args.action,
parameter=args.parameter,
credential_id=args.credential_id,
credentials=args.credentials,
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"options": options})
@console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource):
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])

View File

@ -1081,6 +1081,8 @@ class ToolMCPAuthApi(Resource):
credentials=provider_entity.credentials,
authed=True,
)
# Invalidate cache after updating credentials
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
except MCPAuthError as e:
try:
@ -1094,16 +1096,22 @@ class ToolMCPAuthApi(Resource):
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
# Invalidate cache after auth actions may have updated provider state
ToolProviderListCache.invalidate_cache(tenant_id)
return response
except MCPRefreshTokenError as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
# Invalidate cache after clearing credentials
ToolProviderListCache.invalidate_cache(tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except (MCPError, ValueError) as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
# Invalidate cache after clearing credentials
ToolProviderListCache.invalidate_cache(tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e

View File

@ -1,11 +1,15 @@
import logging
from collections.abc import Mapping
from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@ -32,6 +36,32 @@ from ..wraps import (
logger = logging.getLogger(__name__)
class TriggerSubscriptionUpdateRequest(BaseModel):
"""Request payload for updating a trigger subscription"""
name: str | None = Field(default=None, description="The name for the subscription")
credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
class TriggerSubscriptionVerifyRequest(BaseModel):
"""Request payload for verifying subscription credentials."""
credentials: Mapping[str, Any] = Field(description="The credentials to verify")
console_ns.schema_model(
TriggerSubscriptionUpdateRequest.__name__,
TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
console_ns.schema_model(
TriggerSubscriptionVerifyRequest.__name__,
TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
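These two `schema_model` calls register the Pydantic JSON schemas on the namespace so they can be referenced as `console_ns.models[Model.__name__]` in the `@console_ns.expect` decorators below. Other controllers in this commit use a `register_schema_models` helper for the same purpose; a hedged sketch of what such a helper presumably does, assuming only the flask_restx API used here:

```python
from flask_restx import Namespace
from pydantic import BaseModel

def register_schema_models(ns: Namespace, *models: type[BaseModel]) -> None:
    """Register each Pydantic model's JSON schema under its class name."""
    for model in models:
        ns.schema_model(
            model.__name__,
            model.model_json_schema(ref_template="#/definitions/{model}"),
        )
```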
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
class TriggerProviderIconApi(Resource):
@setup_required
@ -155,16 +185,16 @@ parser_api = (
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderVerifyApi(Resource):
class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
@console_ns.expect(parser_api)
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
"""Verify and update a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -289,6 +319,83 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
raise ValueError(str(e)) from e
@console_ns.route(
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/update",
)
class TriggerSubscriptionUpdateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__])
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, subscription_id: str):
"""Update a subscription instance"""
user = current_user
assert user.current_tenant_id is not None
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise NotFoundError(f"Subscription {subscription_id} not found")
provider_id = TriggerProviderID(subscription.provider_id)
try:
# rename only
if (
args.name is not None
and args.credentials is None
and args.parameters is None
and args.properties is None
):
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
)
return 200
# rebuild so the provider can recreate the subscription automatically
match subscription.credential_type:
case CredentialType.UNAUTHORIZED:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
properties=args.properties,
)
return 200
case CredentialType.API_KEY | CredentialType.OAUTH2:
if args.credentials:
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in args.credentials.items()
}
else:
new_credentials = subscription.credentials
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=args.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=new_credentials,
parameters=args.parameters or subscription.parameters,
)
return 200
case _:
raise BadRequest("Invalid credential type")
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error updating subscription", exc_info=e)
raise
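For API-key and OAuth subscriptions the handler merges the submitted credentials onto the stored ones, treating the masked placeholder as "keep the existing value". A small worked example of that merge, with stand-in values for the constants (the real `HIDDEN_VALUE`/`UNKNOWN_VALUE` live in `constants.py`):

```python
from typing import Any

HIDDEN_VALUE = "[__HIDDEN__]"    # assumed placeholder for masked secrets
UNKNOWN_VALUE = "[__UNKNOWN__]"  # assumed fallback when no stored value exists

def merge_credentials(stored: dict[str, Any], submitted: dict[str, Any]) -> dict[str, Any]:
    """Replace masked fields with the stored value; keep everything the user re-typed."""
    return {
        key: value if value != HIDDEN_VALUE else stored.get(key, UNKNOWN_VALUE)
        for key, value in submitted.items()
    }

stored = {"api_key": "sk-old", "region": "us"}
submitted = {"api_key": HIDDEN_VALUE, "region": "eu"}
assert merge_credentials(stored, submitted) == {"api_key": "sk-old", "region": "eu"}
```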
@console_ns.route(
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)
@ -576,3 +683,38 @@ class TriggerOAuthClientManageApi(Resource):
except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e)
raise
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/verify/<path:subscription_id>",
)
class TriggerSubscriptionVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__])
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_id):
"""Verify credentials for an existing subscription (edit mode only)"""
user = current_user
assert user.current_tenant_id is not None
verify_request: TriggerSubscriptionVerifyRequest = TriggerSubscriptionVerifyRequest.model_validate(
console_ns.payload
)
try:
result = TriggerProviderService.verify_subscription_credentials(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_id=subscription_id,
credentials=verify_request.credentials,
)
return result
except ValueError as e:
logger.warning("Credential verification failed", exc_info=e)
raise BadRequest(str(e)) from e
except Exception as e:
logger.exception("Error verifying subscription credentials", exc_info=e)
raise BadRequest(str(e)) from e

View File

@ -4,7 +4,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
@ -51,6 +51,32 @@ class ConversationRenamePayload(BaseModel):
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
variable_name: str | None = Field(
default=None, description="Filter variables by name", min_length=1, max_length=255
)
@field_validator("variable_name", mode="before")
@classmethod
def validate_variable_name(cls, v: str | None) -> str | None:
"""
Validate variable_name to prevent injection attacks.
"""
if v is None:
return v
# Only allow safe characters: alphanumeric, underscore, hyphen, period
if not v.replace("-", "").replace("_", "").replace(".", "").isalnum():
raise ValueError(
"Variable name can only contain letters, numbers, hyphens (-), underscores (_), and periods (.)"
)
# Prevent SQL injection patterns
dangerous_patterns = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
for pattern in dangerous_patterns:
if pattern in v.lower():
raise ValueError(f"Variable name contains invalid characters: {pattern}")
return v
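A quick check of what the validator above accepts and rejects; the standalone copy below repeats the same two rules (character whitelist, then a blocklist of injection fragments) purely for illustration:

```python
def validate_variable_name(v: str | None) -> str | None:
    """Standalone copy of the checks performed by the field validator above."""
    if v is None:
        return v
    if not v.replace("-", "").replace("_", "").replace(".", "").isalnum():
        raise ValueError("invalid characters")
    for pattern in ("'", '"', ";", "--", "/*", "*/", "xp_", "sp_"):
        if pattern in v.lower():
            raise ValueError(f"invalid characters: {pattern}")
    return v

assert validate_variable_name("user_name.v2") == "user_name.v2"
for bad in ("name;drop", "a'b", "x--y"):
    try:
        validate_variable_name(bad)
    except ValueError:
        pass
    else:
        raise AssertionError(bad)
```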
class ConversationVariableUpdatePayload(BaseModel):
@ -199,7 +225,7 @@ class ConversationVariablesApi(Resource):
try:
return ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, query_args.limit, last_id
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -1,14 +1,13 @@
import logging
from flask import request
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import Unauthorized
from constants import HEADER_NAME_APP_CODE
from controllers.common import fields
from controllers.web import web_ns
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from controllers.common.schema import register_schema_models
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from libs.passport import PassportService
from libs.token import extract_webapp_passport
@ -18,9 +17,23 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
from . import web_ns
from .error import AppUnavailableError
from .wraps import WebApiResource
logger = logging.getLogger(__name__)
class AppAccessModeQuery(BaseModel):
model_config = ConfigDict(populate_by_name=True)
app_id: str | None = Field(default=None, alias="appId", description="Application ID")
app_code: str | None = Field(default=None, alias="appCode", description="Application code")
register_schema_models(web_ns, AppAccessModeQuery)
@web_ns.route("/parameters")
class AppParameterApi(WebApiResource):
"""Resource for app variables."""
@ -96,21 +109,16 @@ class AppAccessMode(Resource):
}
)
def get(self):
parser = (
reqparse.RequestParser()
.add_argument("appId", type=str, required=False, location="args")
.add_argument("appCode", type=str, required=False, location="args")
)
args = parser.parse_args()
raw_args = request.args.to_dict()
args = AppAccessModeQuery.model_validate(raw_args)
features = FeatureService.get_system_features()
if not features.webapp_auth.enabled:
return {"accessMode": "public"}
app_id = args.get("appId")
if args.get("appCode"):
app_code = args["appCode"]
app_id = AppService.get_app_id_by_code(app_code)
app_id = args.app_id
if args.app_code:
app_id = AppService.get_app_id_by_code(args.app_code)
if not app_id:
raise ValueError("appId or appCode must be provided")

View File

@ -2,10 +2,12 @@ import base64
import secrets
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailCodeError,
@ -18,14 +20,40 @@ from controllers.console.error import EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr
language: str | None = None
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr
code: str
token: str = Field(min_length=1)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(min_length=1)
new_password: str
password_confirm: str
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
@web_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource):
@web_ns.expect(web_ns.models[ForgotPasswordSendPayload.__name__])
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@ -40,35 +68,31 @@ class ForgotPasswordSendEmailApi(Resource):
}
)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
if payload.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
token = None
if account is None:
raise AuthenticationFailedError()
else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
return {"result": "success", "data": token}
@web_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource):
@web_ns.expect(web_ns.models[ForgotPasswordCheckPayload.__name__])
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@ -78,45 +102,40 @@ class ForgotPasswordCheckApi(Resource):
responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
user_email = args["email"]
user_email = payload.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"])
token_data = AccountService.get_reset_password_data(payload.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
if payload.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(payload.email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])
AccountService.revoke_reset_password_token(payload.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
user_email, code=payload.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args["email"])
AccountService.reset_forgot_password_error_rate_limit(payload.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@web_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource):
@web_ns.expect(web_ns.models[ForgotPasswordResetPayload.__name__])
@only_edition_enterprise
@setup_required
@email_password_login_enabled
@ -131,20 +150,14 @@ class ForgotPasswordResetApi(Resource):
}
)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = ForgotPasswordResetPayload.model_validate(web_ns.payload or {})
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
if payload.new_password != payload.password_confirm:
raise PasswordMismatchError()
# Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"])
reset_data = AccountService.get_reset_password_data(payload.token)
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
@ -152,11 +165,11 @@ class ForgotPasswordResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
AccountService.revoke_reset_password_token(payload.token)
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
password_hashed = hash_password(args["new_password"], salt)
password_hashed = hash_password(payload.new_password, salt)
email = reset_data.get("email", "")
@ -170,7 +183,7 @@ class ForgotPasswordResetApi(Resource):
return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session):
def _update_existing_account(self, account: Account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
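The reset flow is: random 16-byte salt, hash the new password against it, store both base64-encoded. `libs.password.hash_password` is not shown in this diff, so the sketch below substitutes PBKDF2 only to illustrate the shape of the flow:

```python
import base64
import hashlib
import secrets

def hash_password_stand_in(password: str, salt: bytes) -> bytes:
    # Assumed implementation; the real hash_password may differ.
    return hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100_000)

salt = secrets.token_bytes(16)
password_hashed = hash_password_stand_in("N3w-p4ssword!", salt)

stored_password = base64.b64encode(password_hashed).decode()
stored_salt = base64.b64encode(salt).decode()

# Verification later re-hashes the candidate with the stored salt and compares.
candidate = hash_password_stand_in("N3w-p4ssword!", base64.b64decode(stored_salt))
assert base64.b64encode(candidate).decode() == stored_password
```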

View File

@ -1,9 +1,12 @@
import logging
from typing import Literal
from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
AppMoreLikeThisDisabledError,
@ -38,6 +41,33 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: str = Field(description="Conversation UUID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class MessageMoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] = Field(
description="Response mode",
)
register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, MessageMoreLikeThisQuery)
@web_ns.route("/messages")
class MessageListApi(WebApiResource):
message_fields = {
@ -68,7 +98,11 @@ class MessageListApi(WebApiResource):
@web_ns.doc(
params={
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": True},
"first_id": {"description": "First message ID for pagination", "type": "string", "required": False},
"first_id": {
"description": "First message ID for pagination",
"type": "string",
"required": False,
},
"limit": {
"description": "Number of messages to return (1-100)",
"type": "integer",
@ -93,17 +127,12 @@ class MessageListApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
raw_args = request.args.to_dict()
query = MessageListQuery.model_validate(raw_args)
try:
return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
app_model, end_user, query.conversation_id, query.first_id, query.limit
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -128,7 +157,7 @@ class MessageFeedbackApi(WebApiResource):
"enum": ["like", "dislike"],
"required": False,
},
"content": {"description": "Feedback content/comment", "type": "string", "required": False},
"content": {"description": "Feedback content", "type": "string", "required": False},
}
)
@web_ns.doc(
@ -145,20 +174,15 @@ class MessageFeedbackApi(WebApiResource):
def post(self, app_model, end_user, message_id):
message_id = str(message_id)
parser = (
reqparse.RequestParser()
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
.add_argument("content", type=str, location="json", default=None)
)
args = parser.parse_args()
payload = MessageFeedbackPayload.model_validate(web_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=end_user,
rating=args.get("rating"),
content=args.get("content"),
rating=payload.rating,
content=payload.content,
)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@ -170,17 +194,7 @@ class MessageFeedbackApi(WebApiResource):
class MessageMoreLikeThisApi(WebApiResource):
@web_ns.doc("Generate More Like This")
@web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).")
@web_ns.doc(
params={
"message_id": {"description": "Message UUID", "type": "string", "required": True},
"response_mode": {
"description": "Response mode",
"type": "string",
"enum": ["blocking", "streaming"],
"required": True,
},
}
)
@web_ns.expect(web_ns.models[MessageMoreLikeThisQuery.__name__])
@web_ns.doc(
responses={
200: "Success",
@ -197,12 +211,10 @@ class MessageMoreLikeThisApi(WebApiResource):
message_id = str(message_id)
parser = reqparse.RequestParser().add_argument(
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
)
args = parser.parse_args()
raw_args = request.args.to_dict()
query = MessageMoreLikeThisQuery.model_validate(raw_args)
streaming = args["response_mode"] == "streaming"
streaming = query.response_mode == "streaming"
try:
response = AppGenerateService.generate_more_like_this(

View File

@ -1,7 +1,8 @@
import urllib.parse
import httpx
from flask_restx import marshal_with, reqparse
from flask_restx import marshal_with
from pydantic import BaseModel, Field, HttpUrl
import services
from controllers.common import helpers
@ -10,14 +11,23 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
from services.file_service import FileService
from ..common.schema import register_schema_models
from . import web_ns
from .wraps import WebApiResource
class RemoteFileUploadPayload(BaseModel):
url: HttpUrl = Field(description="Remote file URL")
register_schema_models(web_ns, RemoteFileUploadPayload)
@web_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(WebApiResource):
@ -97,10 +107,8 @@ class RemoteFileUploadApi(WebApiResource):
FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported
"""
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args()
url = args["url"]
payload = RemoteFileUploadPayload.model_validate(web_ns.payload or {})
url = str(payload.url)
try:
resp = ssrf_proxy.head(url=url)

View File

@ -105,8 +105,9 @@ class BaseAppGenerator:
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
and not variable_entity.required
):
# Treat empty string (frontend default) or empty list as unset
if not value and isinstance(value, (str, list)):
# Treat empty string (frontend default) as unset
# For FILE_LIST, allow empty list [] to pass through
if isinstance(value, str) and not value:
return None
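The behavioural change for optional file inputs: an empty string from the frontend still counts as unset, but an empty list is no longer coerced to None, so FILE_LIST variables can legitimately be empty. A before/after sketch of the two predicates:

```python
def is_unset_old(value) -> bool:
    # previous rule: empty string OR empty list counted as unset
    return not value and isinstance(value, (str, list))

def is_unset_new(value) -> bool:
    # new rule: only the empty string counts as unset
    return isinstance(value, str) and not value

assert is_unset_old([]) is True      # [] used to be dropped
assert is_unset_new([]) is False     # [] is now kept for FILE_LIST inputs
assert is_unset_new("") is True      # "" is still treated as unset
```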
if variable_entity.type in {

View File

@ -72,6 +72,22 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
)
def _get_user_provided_host_header(headers: dict | None) -> str | None:
"""
Extract the user-provided Host header from the headers dict.
This is needed because when using a forward proxy, httpx may override the Host header.
We preserve the user's explicit Host header to support virtual hosting and other use cases.
"""
if not headers:
return None
# Case-insensitive lookup for Host header
for key, value in headers.items():
if key.lower() == "host":
return value
return None
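Inside the retry loop below, the headers are rebuilt so that a caller-supplied Host header survives the proxy hop: any case-variant Host entry is stripped and the preserved value is re-added once. The same normalisation in isolation:

```python
def preserve_host_header(headers: dict[str, str]) -> dict[str, str]:
    """Drop any case-variant Host entries and re-add the caller's value once."""
    user_host = next((v for k, v in headers.items() if k.lower() == "host"), None)
    cleaned = {k: v for k, v in headers.items() if k.lower() != "host"}
    if user_host is not None:
        cleaned["host"] = user_host
    return cleaned

headers = {"Authorization": "Bearer t", "HOST": "internal.example.com"}
assert preserve_host_header(headers) == {
    "Authorization": "Bearer t",
    "host": "internal.example.com",
}
```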
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@ -90,10 +106,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
headers = kwargs.get("headers", {})
user_provided_host = _get_user_provided_host_header(headers)
retries = 0
while retries <= max_retries:
try:
# Build the request manually to preserve the Host header
# httpx may override the Host header when using a proxy, so we use
# the request API to explicitly set headers before sending
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
headers["host"] = user_provided_host
kwargs["headers"] = headers
response = client.request(method=method, url=url, **kwargs)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):
# Check if this is a Squid SSRF rejection

View File

@ -396,7 +396,7 @@ class IndexingRunner:
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],

View File

@ -61,6 +61,7 @@ class SSETransport:
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.endpoint_url: str | None = None
self.event_source: EventSource | None = None
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
"""Validate that the endpoint URL matches the connection origin.
@ -237,6 +238,9 @@ class SSETransport:
write_queue: WriteQueue = queue.Queue()
status_queue: StatusQueue = queue.Queue()
# Store event_source for graceful shutdown
self.event_source = event_source
# Start SSE reader thread
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
@ -296,6 +300,13 @@ def sse_client(
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Close the SSE connection to unblock the reader thread
if transport.event_source is not None:
try:
transport.event_source.response.close()
except RuntimeError:
pass
# Clean up queues
if read_queue:
read_queue.put(None)

View File

@ -8,6 +8,7 @@ and session management.
import logging
import queue
import threading
from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
@ -103,6 +104,9 @@ class StreamableHTTPTransport:
CONTENT_TYPE: JSON,
**self.headers,
}
self.stop_event = threading.Event()
self._active_responses: list[httpx.Response] = []
self._lock = threading.Lock()
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available."""
@ -111,6 +115,30 @@ class StreamableHTTPTransport:
headers[MCP_SESSION_ID] = self.session_id
return headers
def _register_response(self, response: httpx.Response):
"""Register a response for cleanup on shutdown."""
with self._lock:
self._active_responses.append(response)
def _unregister_response(self, response: httpx.Response):
"""Unregister a response after it's closed."""
with self._lock:
try:
self._active_responses.remove(response)
except ValueError as e:
logger.debug("Ignoring error during response unregister: %s", e)
def close_active_responses(self):
"""Close all active SSE connections to unblock threads."""
with self._lock:
responses_to_close = list(self._active_responses)
self._active_responses.clear()
for response in responses_to_close:
try:
response.close()
except RuntimeError as e:
logger.debug("Ignoring error during active response close: %s", e)
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request."""
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
@ -195,11 +223,21 @@ class StreamableHTTPTransport:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
for sse in event_source.iter_sse():
self._handle_sse_event(sse, server_to_client_queue)
# Register response for cleanup
self._register_response(event_source.response)
try:
for sse in event_source.iter_sse():
if self.stop_event.is_set():
logger.debug("GET stream received stop signal")
break
self._handle_sse_event(sse, server_to_client_queue)
finally:
self._unregister_response(event_source.response)
except Exception as exc:
logger.debug("GET stream error (non-fatal): %s", exc)
if not self.stop_event.is_set():
logger.debug("GET stream error (non-fatal): %s", exc)
def _handle_resumption_request(self, ctx: RequestContext):
"""Handle a resumption request using GET with SSE."""
@ -224,15 +262,24 @@ class StreamableHTTPTransport:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse():
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
# Register response for cleanup
self._register_response(event_source.response)
try:
for sse in event_source.iter_sse():
if self.stop_event.is_set():
logger.debug("Resumption stream received stop signal")
break
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
finally:
self._unregister_response(event_source.response)
def _handle_post_request(self, ctx: RequestContext):
"""Handle a POST request with response processing."""
@ -295,17 +342,27 @@ class StreamableHTTPTransport:
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
"""Handle SSE response from the server."""
try:
# Register response for cleanup
self._register_response(response)
event_source = EventSource(response)
for sse in event_source.iter_sse():
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
if is_complete:
break
try:
for sse in event_source.iter_sse():
if self.stop_event.is_set():
logger.debug("SSE response stream received stop signal")
break
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
if is_complete:
break
finally:
self._unregister_response(response)
except Exception as e:
ctx.server_to_client_queue.put(e)
if not self.stop_event.is_set():
ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type(
self,
@ -345,6 +402,11 @@ class StreamableHTTPTransport:
"""
while True:
try:
# Check if we should stop
if self.stop_event.is_set():
logger.debug("Post writer received stop signal")
break
# Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if session_message is None:
@ -381,7 +443,8 @@ class StreamableHTTPTransport:
except queue.Empty:
continue
except Exception as exc:
server_to_client_queue.put(exc)
if not self.stop_event.is_set():
server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client):
"""Terminate the session by sending a DELETE request."""
@ -465,6 +528,12 @@ def streamablehttp_client(
transport.get_session_id,
)
finally:
# Set stop event to signal all threads to stop
transport.stop_event.set()
# Close all active SSE connections to unblock threads
transport.close_active_responses()
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
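The shutdown path sets the shared `threading.Event` and closes live SSE responses so blocked reader threads can exit. A minimal, self-contained sketch of that stop-event pattern (a queue reader whose short timeout lets it re-check the flag), independent of the MCP classes above:

```python
import queue
import threading
import time

def reader(q, stop_event, received):
    """Consume items until told to stop; the timeout keeps the loop responsive."""
    while not stop_event.is_set():
        try:
            item = q.get(timeout=0.1)
        except queue.Empty:
            continue
        received.append(item)

q = queue.Queue()
stop = threading.Event()
received = []
t = threading.Thread(target=reader, args=(q, stop, received))
t.start()

q.put("event-1")
time.sleep(0.3)   # give the reader a chance to drain the queue
stop.set()        # signal shutdown; the blocked get() times out and the loop exits
t.join(timeout=2)
assert not t.is_alive() and received == ["event-1"]
```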

View File

@ -54,7 +54,7 @@ def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_
generate dotted_order for langsmith
"""
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f") + "Z"
current_segment = f"{timestamp}{run_id}"
if parent_dotted_order is None:

View File

@ -90,13 +90,17 @@ class Jieba(BaseKeyword):
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
documents = []
segment_query_stmt = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
)
if document_ids_filter:
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
segments = db.session.execute(segment_query_stmt).scalars().all()
segment_map = {segment.index_node_id: segment for segment in segments}
for chunk_index in sorted_chunk_indices:
segment_query = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
)
if document_ids_filter:
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first()
segment = segment_map.get(chunk_index)
if segment:
documents.append(

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
@ -138,37 +139,47 @@ class RetrievalService:
@classmethod
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
"""Deduplicate documents in O(n) while preserving first-seen order.
Rules:
- For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
- For non-dify documents (or dify without doc_id): deduplicate by content key
(provider, page_content), keeping the first occurrence.
"""
if not documents:
return documents
unique_documents = []
seen_doc_ids = set()
# Map of dedup key -> chosen Document
chosen: dict[tuple, Document] = {}
# Preserve the order of first appearance of each dedup key
order: list[tuple] = []
for document in documents:
# For dify provider documents, use doc_id for deduplication
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
doc_id = document.metadata["doc_id"]
if doc_id not in seen_doc_ids:
seen_doc_ids.add(doc_id)
unique_documents.append(document)
# If duplicate, keep the one with higher score
elif "score" in document.metadata:
# Find existing document with same doc_id and compare scores
for i, existing_doc in enumerate(unique_documents):
if (
existing_doc.metadata
and existing_doc.metadata.get("doc_id") == doc_id
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
):
unique_documents[i] = document
break
for doc in documents:
is_dify = doc.provider == "dify"
doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
if is_dify and doc_id:
key = ("dify", doc_id)
if key not in chosen:
chosen[key] = doc
order.append(key)
else:
# Only replace if the new one has a score and it's strictly higher
if "score" in doc.metadata:
new_score = float(doc.metadata.get("score", 0.0))
old_score = float(chosen[key].metadata.get("score", 0.0)) if chosen[key].metadata else 0.0
if new_score > old_score:
chosen[key] = doc
else:
# For non-dify documents, use content-based deduplication
if document not in unique_documents:
unique_documents.append(document)
# Content-based dedup for non-dify or dify without doc_id
content_key = (doc.provider or "dify", doc.page_content)
if content_key not in chosen:
chosen[content_key] = doc
order.append(content_key)
# If duplicate content appears, we keep the first occurrence (no score comparison)
return unique_documents
return [chosen[k] for k in order]
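A worked example of the rules in the docstring, using a small stand-in for the real `Document` model:

```python
from dataclasses import dataclass, field

@dataclass
class Doc:  # stand-in for the real Document model
    page_content: str
    provider: str = "dify"
    metadata: dict = field(default_factory=dict)

docs = [
    Doc("chunk A", metadata={"doc_id": "1", "score": 0.40}),
    Doc("chunk A*", metadata={"doc_id": "1", "score": 0.90}),  # same doc_id, higher score
    Doc("external text", provider="external"),
    Doc("external text", provider="external"),                 # duplicate content
]

# Expected outcome of _deduplicate_documents:
#   - doc_id "1" keeps the 0.90-scored variant, in its first-seen position
#   - the duplicated external content keeps only the first occurrence
# i.e. ["chunk A*", "external text"]
```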
@classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
@ -371,58 +382,96 @@ class RetrievalService:
include_segment_ids = set()
segment_child_map = {}
segment_file_map = {}
with Session(bind=db.engine, expire_on_commit=False) as session:
# Process documents
for document in documents:
segment_id = None
attachment_info = None
child_chunk = None
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
valid_dataset_documents = {}
image_doc_ids = []
child_index_node_ids = []
index_node_ids = []
doc_to_document_map = {}
for document in documents:
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# Handle parent-child documents
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
else:
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = session.scalar(child_chunk_stmt)
dataset_document = dataset_documents[document_id]
if not dataset_document:
continue
valid_dataset_documents[document_id] = dataset_document
if not child_chunk:
continue
segment_id = child_chunk.segment_id
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id)
else:
child_index_node_ids.append(doc_id)
else:
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id)
else:
index_node_ids.append(doc_id)
if not segment_id:
continue
image_doc_ids = [i for i in image_doc_ids if i]
child_index_node_ids = [i for i in child_index_node_ids if i]
index_node_ids = [i for i in index_node_ids if i]
segment = (
session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
.first()
)
segment_ids = []
index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map = {}
child_chunk_map = {}
doc_segment_map = {}
if not segment:
continue
with session_factory.create_session() as session:
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
for attachment in attachments:
segment_ids.append(attachment["segment_id"])
attachment_map[attachment["segment_id"]] = attachment
doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
segment_ids.append(i.segment_id)
child_chunk_map[i.segment_id] = i
doc_segment_map[i.segment_id] = i.index_node_id
if index_node_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id.in_(index_node_ids),
)
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
for index_node_segment in index_node_segments:
doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
if segment_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id.in_(segment_ids),
)
segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
if index_node_segments:
segments.extend(index_node_segments)
for segment in segments:
doc_id = doc_segment_map.get(segment.id)
child_chunk = child_chunk_map.get(segment.id)
attachment_info = attachment_map.get(segment.id)
if doc_id:
document = doc_to_document_map[doc_id]
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
document.metadata.get("document_id")
)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
if child_chunk:
@ -430,10 +479,10 @@ class RetrievalService:
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
"score": document.metadata.get("score", 0.0) if document else 0.0,
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
@ -452,13 +501,14 @@ class RetrievalService:
"score": document.metadata.get("score", 0.0),
}
if segment.id in segment_child_map:
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
segment_child_map[segment.id]["max_score"],
document.metadata.get("score", 0.0) if document else 0.0,
)
else:
segment_child_map[segment.id] = {
"max_score": document.metadata.get("score", 0.0),
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
"child_chunks": [child_chunk_detail],
}
if attachment_info:
@ -467,46 +517,11 @@ class RetrievalService:
else:
segment_file_map[segment.id] = [attachment_info]
else:
# Handle normal documents
segment = None
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
segment = session.scalar(document_segment_stmt)
if segment:
segment_file_map[segment.id] = [attachment_info]
else:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = session.scalar(document_segment_stmt)
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score"), # type: ignore
"score": document.metadata.get("score", 0.0), # type: ignore
}
if attachment_info:
segment_file_map[segment.id] = [attachment_info]
@ -522,7 +537,7 @@ class RetrievalService:
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"]
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
if record["segment"].id in segment_file_map:
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
@ -565,6 +580,8 @@ class RetrievalService:
flask_app: Flask,
retrieval_method: RetrievalMethod,
dataset: Dataset,
all_documents: list[Document],
exceptions: list[str],
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
@ -573,8 +590,6 @@ class RetrievalService:
weights: dict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
all_documents: list[Document] = [],
exceptions: list[str] = [],
):
if not query and not attachment_id:
return
@ -696,3 +711,37 @@ class RetrievalService:
}
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
return None
@classmethod
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
attachment_infos = []
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
.all()
)
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
if attachment_bindings:
for upload_file in upload_files:
attachment_binding = attachment_binding_map.get(upload_file.id)
attachment_info = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
"mime_type": upload_file.mime_type,
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
"size": upload_file.size,
}
if attachment_binding:
attachment_infos.append(
{
"attachment_id": attachment_binding.attachment_id,
"attachment_info": attachment_info,
"segment_id": attachment_binding.segment_id,
}
)
return attachment_infos

View File

@ -289,7 +289,8 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名ns: 地名nt: 机构名
# `nr`: Person, `ns`: Location, `nt`: Organization
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:
current_entity += word
else:
if current_entity:

View File

@ -213,7 +213,7 @@ class VastbaseVector(BaseVector):
with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# Vastbase 支持的向量维度取值范围为 [1,16000]
# Vastbase supports vector dimensions in the range [1, 16,000]
if dimension <= 16000:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)

View File

@ -25,7 +25,7 @@ class FirecrawlApp:
}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers)
response = self._post_request(self._build_url("v2/scrape"), json_data, headers)
if response.status_code == 200:
response_data = response.json()
data = response_data["data"]
@ -42,7 +42,7 @@ class FirecrawlApp:
json_data = {"url": url}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers)
response = self._post_request(self._build_url("v2/crawl"), json_data, headers)
if response.status_code == 200:
# There are also two other fields in the response: "success" (bool) and "url" (str)
job_id = response.json().get("id")
@ -58,7 +58,7 @@ class FirecrawlApp:
if params:
# Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/map", json_data, headers)
response = self._post_request(self._build_url("v2/map"), json_data, headers)
if response.status_code == 200:
return cast(dict[str, Any], response.json())
elif response.status_code in {402, 409, 500, 429, 408}:
@ -69,7 +69,7 @@ class FirecrawlApp:
def check_crawl_status(self, job_id) -> dict[str, Any]:
headers = self._prepare_headers()
response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers)
response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers)
if response.status_code == 200:
crawl_status_response = response.json()
if crawl_status_response.get("status") == "completed":
@ -120,6 +120,10 @@ class FirecrawlApp:
def _prepare_headers(self) -> dict[str, Any]:
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _build_url(self, path: str) -> str:
# ensure exactly one slash between base and path, regardless of user-provided base_url
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
for attempt in range(retries):
response = httpx.post(url, headers=headers, json=data)
@ -139,7 +143,11 @@ class FirecrawlApp:
return response
def _handle_error(self, response, action):
error_message = response.json().get("error", "Unknown error occurred")
try:
payload = response.json()
error_message = payload.get("error") or payload.get("message") or response.text or "Unknown error occurred"
except json.JSONDecodeError:
error_message = response.text or "Unknown error occurred"
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
@ -160,7 +168,7 @@ class FirecrawlApp:
}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/search", json_data, headers)
response = self._post_request(self._build_url("v2/search"), json_data, headers)
if response.status_code == 200:
response_data = response.json()
if not response_data.get("success"):

View File

@ -48,13 +48,21 @@ class NotionExtractor(BaseExtractor):
if notion_access_token:
self._notion_access_token = notion_access_token
else:
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
if not self._notion_access_token:
try:
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
except Exception as e:
logger.warning(
(
"Failed to get Notion access token from datasource credentials: %s, "
"falling back to environment variable NOTION_INTEGRATION_TOKEN"
),
e,
)
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
if integration_token is None:
raise ValueError(
"Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`."
)
) from e
self._notion_access_token = integration_token

View File

@ -231,7 +231,7 @@ class BaseIndexProcessor(ABC):
if not filename:
parsed_url = urlparse(image_url)
# unquote 处理 URL 中的中文
# Decode percent-encoded characters in the URL path.
path = unquote(parsed_url.path)
filename = os.path.basename(path)

View File

@ -151,20 +151,14 @@ class DatasetRetrieval:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available
if not dataset:
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
for dataset in datasets:
if dataset.available_document_count == 0 and dataset.provider != "external":
continue
# pass if dataset is not available
if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
continue
available_datasets.append(dataset)
if inputs:
inputs = {key: str(value) for key, value in inputs.items()}
else:
@ -282,26 +276,35 @@ class DatasetRetrieval:
)
context_files.append(attachment_info)
if show_retrieve_source:
dataset_ids = [record.segment.dataset_id for record in records]
document_ids = [record.segment.document_id for record in records]
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id.in_(document_ids),
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
documents = db.session.execute(dataset_document_stmt).scalars().all() # type: ignore
dataset_stmt = select(Dataset).where(
Dataset.id.in_(dataset_ids),
)
datasets = db.session.execute(dataset_stmt).scalars().all() # type: ignore
dataset_map = {i.id: i for i in datasets}
document_map = {i.id: i for i in documents}
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
dataset_item = dataset_map.get(segment.dataset_id)
document_item = document_map.get(segment.document_id)
if dataset_item and document_item:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
dataset_id=dataset_item.id,
dataset_name=dataset_item.name,
document_id=document_item.id,
document_name=document_item.name,
data_source_type=document_item.data_source_type,
segment_id=segment.id,
retriever_from=invoke_from.to_source(),
score=record.score or 0.0,
doc_metadata=document.doc_metadata,
doc_metadata=document_item.doc_metadata,
)
if invoke_from.to_source() == "dev":
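Both hunks replace per-record queries with a single `IN (...)` select plus an id-keyed map, so the retrieval-source loop no longer issues one query per dataset and document. A minimal sketch of that batching pattern in SQLAlchemy 2.x style (model and session names are illustrative):

```python
from sqlalchemy import select


def load_by_ids(session, model, ids):
    """Fetch all matching rows in one round trip and index them by primary key."""
    rows = session.execute(select(model).where(model.id.in_(ids))).scalars().all()
    return {row.id: row for row in rows}


# Usage sketch: one query per model instead of one per record.
# dataset_map = load_by_ids(session, Dataset, dataset_ids)
# document_map = load_by_ids(session, DatasetDocument, document_ids)
# dataset_item = dataset_map.get(segment.dataset_id)
```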

View File

@ -153,11 +153,11 @@ class ToolInvokeMessage(BaseModel):
@classmethod
def transform_variable_value(cls, values):
"""
Only basic types and lists are allowed.
Only basic types, lists, and None are allowed.
"""
value = values.get("variable_value")
if not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types and lists are allowed.")
if value is not None and not isinstance(value, dict | list | str | int | float | bool):
raise ValueError("Only basic types, lists, and None are allowed.")
# if stream is true, the value must be a string
if values.get("stream"):
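The validator now lets `None` through in addition to the basic scalar and container types. A small standalone sketch of the relaxed check (plain function, not the pydantic validator itself):

```python
def validate_variable_value(value):
    # None is allowed; anything else must be a plain dict/list/str/int/float/bool.
    if value is not None and not isinstance(value, dict | list | str | int | float | bool):
        raise ValueError("Only basic types, lists, and None are allowed.")
    return value


validate_variable_value(None)       # accepted after this change
validate_variable_value({"k": 1})   # still accepted
# validate_variable_value(object()) # would raise ValueError
```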

View File

@ -67,12 +67,16 @@ def create_trigger_provider_encrypter_for_subscription(
def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
cache = TriggerProviderCredentialsCache(
TriggerProviderCredentialsCache(
tenant_id=tenant_id,
provider_id=provider_id,
credential_id=subscription_id,
)
cache.delete()
).delete()
TriggerProviderPropertiesCache(
tenant_id=tenant_id,
provider_id=provider_id,
subscription_id=subscription_id,
).delete()
def create_trigger_provider_encrypter_for_properties(

View File

@ -247,6 +247,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node
class WorkflowNodeExecutionStatus(StrEnum):

View File

@ -1,3 +1,4 @@
from enum import StrEnum
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
@ -96,3 +97,8 @@ class LoopState(BaseLoopState):
Get current output.
"""
return self.current_output
class LoopCompletedReason(StrEnum):
LOOP_BREAK = "loop_break"
LOOP_COMPLETED = "loop_completed"
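Since `LoopCompletedReason` is a `StrEnum`, its members compare equal to their string values, so storing either the member or `.value` in metadata yields the same serialized string. A quick illustration:

```python
from enum import StrEnum


class LoopCompletedReason(StrEnum):
    LOOP_BREAK = "loop_break"
    LOOP_COMPLETED = "loop_completed"


assert LoopCompletedReason.LOOP_BREAK == "loop_break"
assert LoopCompletedReason.LOOP_COMPLETED.value == "loop_completed"
```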

View File

@ -29,7 +29,7 @@ from core.workflow.node_events import (
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
from libs.datetime_utils import naive_utc_now
@ -96,6 +96,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
# Start Loop event
yield LoopStartedEvent(
@ -118,6 +119,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_count = 0
for i in range(loop_count):
# Clear stale variables from previous loop iterations to avoid streaming old values
self._clear_loop_subgraph_variables(loop_node_ids)
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
loop_start_time = naive_utc_now()
@ -177,7 +180,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: (
LoopCompletedReason.LOOP_BREAK
if reach_break_condition
else LoopCompletedReason.LOOP_COMPLETED.value
),
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
@ -274,6 +281,17 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None:
"""
Remove variables produced by loop sub-graph nodes from previous iterations.
Keeping stale variables causes a freshly created response coordinator in the
next iteration to fall back to outdated values when no stream chunks exist.
"""
variable_pool = self.graph_runtime_state.variable_pool
for node_id in loop_node_ids:
variable_pool.remove([node_id])
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
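The new `_clear_loop_subgraph_variables` step removes the outputs written by loop-body nodes before each iteration starts, so a freshly created response coordinator cannot fall back to the previous iteration's values. A toy model of the idea, assuming only that the variable pool can drop everything a node produced via `remove([node_id])` as shown in the hunk above:

```python
class ToyVariablePool:
    """Minimal pool keyed by (node_id, variable_name); illustrative only."""

    def __init__(self):
        self._store: dict[tuple[str, str], object] = {}

    def add(self, selector: tuple[str, str], value: object) -> None:
        self._store[selector] = value

    def remove(self, selector: list[str]) -> None:
        # A one-element selector drops every variable that node produced.
        (node_id,) = selector
        self._store = {key: value for key, value in self._store.items() if key[0] != node_id}


pool = ToyVariablePool()
pool.add(("llm_node", "text"), "stale output from iteration 1")
pool.remove(["llm_node"])  # called at the start of the next iteration
assert not pool._store
```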

View File

@ -281,7 +281,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
# handle invoke result
text = invoke_result.message.content or ""
text = invoke_result.message.get_text_content()
if not isinstance(text, str):
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")

View File

@ -26,6 +26,7 @@ class AliyunOssStorage(BaseStorage):
self.bucket_name,
connect_timeout=30,
region=region,
cloudbox_id=dify_config.ALIYUN_CLOUDBOX_ID,
)
def save(self, filename, data):

View File

@ -17,6 +17,7 @@ class HuaweiObsStorage(BaseStorage):
access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
server=dify_config.HUAWEI_OBS_SERVER,
path_style=dify_config.HUAWEI_OBS_PATH_STYLE,
)
def save(self, filename, data):

View File

@ -69,7 +69,7 @@ dependencies = [
"pydantic-extra-types~=2.10.3",
"pydantic-settings~=2.11.0",
"pyjwt~=2.10.1",
"pypdfium2==4.30.0",
"pypdfium2==5.2.0",
"python-docx~=1.1.0",
"python-dotenv==1.0.1",
"pyyaml~=6.0.1",

View File

@ -155,6 +155,7 @@ class AppDslService:
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
and "/blob/" in parsed_url.path
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
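With the extra `'/blob/' in parsed_url.path` guard, only genuine GitHub blob URLs are rewritten to raw.githubusercontent.com; other github.com YAML links (for example user-attachments uploads) are fetched unchanged. A small sketch of the decision:

```python
from urllib.parse import urlparse


def resolve_yaml_url(yaml_url: str) -> str:
    parsed = urlparse(yaml_url)
    if (
        parsed.scheme == "https"
        and parsed.netloc == "github.com"
        and parsed.path.endswith((".yml", ".yaml"))
        and "/blob/" in parsed.path
    ):
        yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
        yaml_url = yaml_url.replace("/blob/", "/")
    return yaml_url


assert resolve_yaml_url("https://github.com/acme/repo/blob/main/app.yml") == (
    "https://raw.githubusercontent.com/acme/repo/main/app.yml"
)
assert resolve_yaml_url("https://github.com/user-attachments/files/1/app.yml").startswith("https://github.com/")
```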

View File

@ -26,7 +26,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
"limit": 1,
"scrapeOptions": {"onlyMainContent": True},
}
response = self._post_request(f"{self.base_url}/v1/crawl", options, headers)
response = self._post_request(self._build_url("v1/crawl"), options, headers)
if response.status_code == 200:
return True
else:
@ -35,15 +35,17 @@ class FirecrawlAuth(ApiKeyAuthBase):
def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _build_url(self, path: str) -> str:
# ensure exactly one slash between base and path, regardless of user-provided base_url
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
def _post_request(self, url, data, headers):
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
else:
if response.text:
error_message = json.loads(response.text).get("error", "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
try:
payload = response.json()
except json.JSONDecodeError:
payload = {}
error_message = payload.get("error") or payload.get("message") or (response.text or "Unknown error occurred")
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")

View File

@ -6,7 +6,9 @@ from typing import Any, Union
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator
from core.variables.types import SegmentType
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
@ -202,6 +204,7 @@ class ConversationService:
user: Union[Account, EndUser] | None,
limit: int,
last_id: str | None,
variable_name: str | None = None,
) -> InfiniteScrollPagination:
conversation = cls.get_conversation(app_model, conversation_id, user)
@ -212,7 +215,25 @@ class ConversationService:
.order_by(ConversationVariable.created_at)
)
with Session(db.engine) as session:
# Apply variable_name filter if provided
if variable_name:
# Escape LIKE wildcards so the user-provided name is matched literally
escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
# Filter using JSON extraction to match variable names case-insensitively
if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
stmt = stmt.where(
func.json_extract(ConversationVariable.data, "$.name").ilike(
f"%{escaped_variable_name}%", escape="\\"
)
)
elif dify_config.DB_TYPE == "postgresql":
stmt = stmt.where(
func.json_extract_path_text(ConversationVariable.data, "name").ilike(
f"%{escaped_variable_name}%", escape="\\"
)
)
with session_factory.create_session() as session:
if last_id:
last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id))
if not last_variable:
@ -279,7 +300,7 @@ class ConversationService:
.where(ConversationVariable.id == variable_id)
)
with Session(db.engine) as session:
with session_factory.create_session() as session:
existing_variable = session.scalar(stmt)
if not existing_variable:
raise ConversationVariableNotExistsError()
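The name filter first escapes SQL LIKE wildcards, then applies the JSON-extraction function appropriate for the configured database dialect. A hedged sketch of building such a condition (the dialect names and the `data -> name` layout follow the hunk above and should be treated as assumptions):

```python
from sqlalchemy import func


def variable_name_filter(json_column, variable_name: str, db_type: str):
    """Return an ILIKE condition on the JSON 'name' field, or None if the dialect is unsupported."""
    escaped = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    pattern = f"%{escaped}%"
    if db_type in ("mysql", "oceanbase", "seekdb"):
        return func.json_extract(json_column, "$.name").ilike(pattern, escape="\\")
    if db_type == "postgresql":
        return func.json_extract_path_text(json_column, "name").ilike(pattern, escape="\\")
    return None
```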

View File

@ -105,3 +105,49 @@ class PluginParameterService:
)
.options
)
@staticmethod
def get_dynamic_select_options_with_credentials(
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
action: str,
parameter: str,
credential_id: str,
credentials: Mapping[str, Any],
) -> Sequence[PluginParameterOption]:
"""
Get dynamic select options using provided credentials directly.
Used for edit mode when credentials have been modified but not yet saved.
Security: credential_id is validated against tenant_id to ensure
users can only access their own credentials.
"""
from constants import HIDDEN_VALUE
# Get original subscription to replace hidden values (with tenant_id check for security)
original_subscription = TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
if not original_subscription:
raise ValueError(f"Subscription {credential_id} not found")
# Replace [__HIDDEN__] with original values
resolved_credentials: dict[str, Any] = {
key: (original_subscription.credentials.get(key) if value == HIDDEN_VALUE else value)
for key, value in credentials.items()
}
return (
DynamicSelectClient()
.fetch_dynamic_select_options(
tenant_id,
user_id,
plugin_id,
provider,
action,
resolved_credentials,
CredentialType.API_KEY.value,
parameter,
)
.options
)
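Submitted credentials may carry the hidden placeholder for fields the user left untouched; those entries are swapped back to the stored originals before the plugin is called. A minimal sketch of that resolution (the placeholder literal is assumed to mirror `constants.HIDDEN_VALUE`):

```python
HIDDEN_VALUE = "[__HIDDEN__]"  # assumed value of constants.HIDDEN_VALUE


def resolve_credentials(submitted: dict, original: dict) -> dict:
    # Keep edited values; restore originals where the client sent the placeholder.
    return {
        key: (original.get(key) if value == HIDDEN_VALUE else value)
        for key, value in submitted.items()
    }


resolved = resolve_credentials(
    {"api_key": HIDDEN_VALUE, "endpoint": "https://new.example.com"},
    {"api_key": "stored-secret", "endpoint": "https://old.example.com"},
)
assert resolved == {"api_key": "stored-secret", "endpoint": "https://new.example.com"}
```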

View File

@ -94,16 +94,23 @@ class TriggerProviderService:
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
for subscription in subscriptions:
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.credentials = dict(
encrypter.mask_credentials(dict(encrypter.decrypt(subscription.credentials)))
credential_encrypter.mask_credentials(dict(credential_encrypter.decrypt(subscription.credentials)))
)
subscription.properties = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.properties))))
subscription.parameters = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.parameters))))
properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
subscription.properties = dict(
properties_encrypter.mask_credentials(dict(properties_encrypter.decrypt(subscription.properties)))
)
subscription.parameters = dict(subscription.parameters)
count = workflows_in_use_map.get(subscription.id)
subscription.workflows_in_use = count if count is not None else 0
@ -209,6 +216,101 @@ class TriggerProviderService:
logger.exception("Failed to add trigger provider")
raise ValueError(str(e))
@classmethod
def update_trigger_subscription(
cls,
tenant_id: str,
subscription_id: str,
name: str | None = None,
properties: Mapping[str, Any] | None = None,
parameters: Mapping[str, Any] | None = None,
credentials: Mapping[str, Any] | None = None,
credential_expires_at: int | None = None,
expires_at: int | None = None,
) -> None:
"""
Update an existing trigger subscription.
:param tenant_id: Tenant ID
:param subscription_id: Subscription instance ID
:param name: Optional new name for this subscription
:param properties: Optional new properties
:param parameters: Optional new parameters
:param credentials: Optional new credentials
:param credential_expires_at: Optional new credential expiration timestamp
:param expires_at: Optional new expiration timestamp
:return: None
"""
with Session(db.engine, expire_on_commit=False) as session:
# Use distributed lock to prevent race conditions on the same subscription
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
with redis_client.lock(lock_key, timeout=20):
subscription: TriggerSubscription | None = (
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
)
if not subscription:
raise ValueError(f"Trigger subscription {subscription_id} not found")
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
# Check for name uniqueness if name is being updated
if name is not None and name != subscription.name:
existing = (
session.query(TriggerSubscription)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
.first()
)
if existing:
raise ValueError(f"Subscription name '{name}' already exists for this provider")
subscription.name = name
# Update properties if provided
if properties is not None:
properties_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_properties_schema(),
cache=NoOpProviderCredentialCache(),
)
# Handle hidden values - preserve original encrypted values
original_properties = properties_encrypter.decrypt(subscription.properties)
new_properties: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else original_properties.get(key, UNKNOWN_VALUE)
for key, value in properties.items()
}
subscription.properties = dict(properties_encrypter.encrypt(new_properties))
# Update parameters if provided
if parameters is not None:
subscription.parameters = dict(parameters)
# Update credentials if provided
if credentials is not None:
credential_type = CredentialType.of(subscription.credential_type)
credential_encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(),
)
subscription.credentials = dict(credential_encrypter.encrypt(dict(credentials)))
# Update credential expiration timestamp if provided
if credential_expires_at is not None:
subscription.credential_expires_at = credential_expires_at
# Update expiration timestamp if provided
if expires_at is not None:
subscription.expires_at = expires_at
session.commit()
# Clear subscription cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
subscription_id=subscription.id,
)
@classmethod
def get_subscription_by_id(cls, tenant_id: str, subscription_id: str | None = None) -> TriggerSubscription | None:
"""
@ -257,17 +359,18 @@ class TriggerProviderService:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
credential_type: CredentialType = CredentialType.of(subscription.credential_type)
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
if is_auto_created:
provider_id = TriggerProviderID(subscription.provider_id)
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=tenant_id, provider_id=provider_id
)
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
tenant_id=tenant_id,
controller=provider_controller,
subscription=subscription,
)
try:
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
@ -280,8 +383,8 @@ class TriggerProviderService:
except Exception as e:
logger.exception("Error unsubscribing trigger", exc_info=e)
# Clear cache
session.delete(subscription)
# Clear cache
delete_cache_for_subscription(
tenant_id=tenant_id,
provider_id=subscription.provider_id,
@ -688,3 +791,125 @@ class TriggerProviderService:
)
subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
return subscription
@classmethod
def verify_subscription_credentials(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
subscription_id: str,
credentials: Mapping[str, Any],
) -> dict[str, Any]:
"""
Verify credentials for an existing subscription without updating it.
This is used in edit mode to validate new credentials before rebuild.
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider_id: Provider identifier
:param subscription_id: Subscription ID
:param credentials: New credentials to verify
:return: dict with 'verified' boolean
"""
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
subscription = cls.get_subscription_by_id(
tenant_id=tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
credential_type = CredentialType.of(subscription.credential_type)
# For API Key, validate the new credentials
if credential_type == CredentialType.API_KEY:
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
try:
provider_controller.validate_credentials(user_id, credentials=new_credentials)
return {"verified": True}
except Exception as e:
raise ValueError(f"Invalid credentials: {e}") from e
return {"verified": True}
@classmethod
def rebuild_trigger_subscription(
cls,
tenant_id: str,
provider_id: TriggerProviderID,
subscription_id: str,
credentials: Mapping[str, Any],
parameters: Mapping[str, Any],
name: str | None = None,
) -> None:
"""
Create a subscription builder for rebuilding an existing subscription.
This method creates a builder pre-filled with data from the rebuild request,
keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:param subscription_id: Subscription ID
:param credentials: Credentials for the subscription
:param parameters: Parameters for the subscription
:param name: Optional new name for the subscription
:return: None
"""
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
if not provider_controller:
raise ValueError(f"Provider {provider_id} not found")
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=tenant_id,
subscription_id=subscription_id,
)
if not subscription:
raise ValueError(f"Subscription {subscription_id} not found")
credential_type = CredentialType.of(subscription.credential_type)
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
raise ValueError("Credential type not supported for rebuild")
# TODO: Try invoking the update API of the plugin trigger provider
# FALLBACK: If the update API is not implemented, delete the previous subscription and create a new one
# Delete the previous subscription
user_id = subscription.user_id
TriggerManager.unsubscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
subscription=subscription.to_entity(),
credentials=subscription.credentials,
credential_type=credential_type,
)
# Create a new subscription with the same subscription_id and endpoint_id
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
parameters=parameters,
credentials=credentials,
credential_type=credential_type,
)
TriggerProviderService.update_trigger_subscription(
tenant_id=tenant_id,
subscription_id=subscription.id,
name=name,
parameters=parameters,
credentials=credentials,
properties=new_subscription.properties,
expires_at=new_subscription.expires_at,
)
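`update_trigger_subscription` serializes concurrent edits to the same subscription behind a per-tenant Redis lock and clears the caches after the row is committed. A hedged sketch of that locking shape, assuming a redis-py style client whose `lock()` returns a context manager:

```python
def update_subscription_with_lock(redis_client, tenant_id: str, subscription_id: str, apply_update, clear_cache):
    lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
    # The timeout keeps a crashed holder from blocking later updates forever.
    with redis_client.lock(lock_key, timeout=20):
        apply_update()  # mutate and commit the subscription row
    clear_cache()       # invalidate cached credentials/properties afterwards
```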

View File

@ -453,11 +453,12 @@ class TriggerSubscriptionBuilderService:
if not subscription_builder:
return None
# response to validation endpoint
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription_builder.tenant_id, provider_id=TriggerProviderID(subscription_builder.provider_id)
)
try:
# response to validation endpoint
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
tenant_id=subscription_builder.tenant_id,
provider_id=TriggerProviderID(subscription_builder.provider_id),
)
dispatch_response: TriggerDispatchResponse = controller.dispatch(
request=request,
subscription=subscription_builder.to_subscription(),

View File

@ -163,34 +163,17 @@ class TestActivateApi:
"account": mock_account,
}
@pytest.fixture
def mock_token_pair(self):
"""Create mock token pair object."""
token_pair = MagicMock()
token_pair.access_token = "access_token"
token_pair.refresh_token = "refresh_token"
token_pair.csrf_token = "csrf_token"
token_pair.model_dump.return_value = {
"access_token": "access_token",
"refresh_token": "refresh_token",
"csrf_token": "csrf_token",
}
return token_pair
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_successful_account_activation(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_account,
mock_token_pair,
):
"""
Test successful account activation.
@ -198,12 +181,10 @@ class TestActivateApi:
Verifies that:
- Account is activated with user preferences
- Account status is set to ACTIVE
- User is logged in after activation
- Invitation token is revoked
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
@ -230,7 +211,6 @@ class TestActivateApi:
assert mock_account.initialized_at is not None
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
mock_db.session.commit.assert_called_once()
mock_login.assert_called_once()
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
def test_activation_with_invalid_token(self, mock_get_invitation, app):
@ -264,17 +244,14 @@ class TestActivateApi:
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_sets_interface_theme(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_account,
mock_token_pair,
):
"""
Test that activation sets default interface theme.
@ -284,7 +261,6 @@ class TestActivateApi:
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
@ -317,17 +293,14 @@ class TestActivateApi:
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_with_different_locales(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_account,
mock_token_pair,
language,
timezone,
):
@ -341,7 +314,6 @@ class TestActivateApi:
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
@ -367,27 +339,23 @@ class TestActivateApi:
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_returns_token_data(
def test_activation_returns_success_response(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_token_pair,
):
"""
Test that activation returns authentication tokens.
Test that activation returns a success response without authentication tokens.
Verifies that:
- Token pair is returned in response
- All token types are included (access, refresh, csrf)
- Response contains a success result
- No token data is returned
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(
@ -406,24 +374,18 @@ class TestActivateApi:
response = api.post()
# Assert
assert "data" in response
assert response["data"]["access_token"] == "access_token"
assert response["data"]["refresh_token"] == "refresh_token"
assert response["data"]["csrf_token"] == "csrf_token"
assert response == {"result": "success"}
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
def test_activation_without_workspace_id(
self,
mock_login,
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
mock_invitation,
mock_token_pair,
):
"""
Test account activation without workspace_id.
@ -434,7 +396,6 @@ class TestActivateApi:
"""
# Arrange
mock_get_invitation.return_value = mock_invitation
mock_login.return_value = mock_token_pair
# Act
with app.test_request_context(

View File

@ -0,0 +1,236 @@
from __future__ import annotations
import builtins
import uuid
from datetime import UTC, datetime
from unittest.mock import MagicMock
import pytest
from flask import Flask
from flask.views import MethodView as FlaskMethodView
_NEEDS_METHOD_VIEW_CLEANUP = False
if not hasattr(builtins, "MethodView"):
builtins.MethodView = FlaskMethodView
_NEEDS_METHOD_VIEW_CLEANUP = True
from constants import HIDDEN_VALUE
from controllers.console.extension import (
APIBasedExtensionAPI,
APIBasedExtensionDetailAPI,
CodeBasedExtensionAPI,
)
if _NEEDS_METHOD_VIEW_CLEANUP:
delattr(builtins, "MethodView")
from models.account import AccountStatus
from models.api_based_extension import APIBasedExtension
def _make_extension(
*,
name: str = "Sample Extension",
api_endpoint: str = "https://example.com/api",
api_key: str = "super-secret-key",
) -> APIBasedExtension:
extension = APIBasedExtension(
tenant_id="tenant-123",
name=name,
api_endpoint=api_endpoint,
api_key=api_key,
)
extension.id = f"{uuid.uuid4()}"
extension.created_at = datetime.now(tz=UTC)
return extension
@pytest.fixture(autouse=True)
def _mock_console_guards(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
"""Bypass console decorators so handlers can run in isolation."""
import controllers.console.extension as extension_module
from controllers.console import wraps as wraps_module
account = MagicMock()
account.status = AccountStatus.ACTIVE
account.current_tenant_id = "tenant-123"
account.id = "account-123"
account.is_authenticated = True
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
monkeypatch.setattr("libs.login.dify_config.LOGIN_DISABLED", True)
monkeypatch.delenv("INIT_PASSWORD", raising=False)
monkeypatch.setattr(extension_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
# The login_required decorator consults the shared LocalProxy in libs.login.
monkeypatch.setattr("libs.login.current_user", account)
monkeypatch.setattr("libs.login.check_csrf_token", lambda *_, **__: None)
return account
@pytest.fixture(autouse=True)
def _restx_mask_defaults(app: Flask):
app.config.setdefault("RESTX_MASK_HEADER", "X-Fields")
app.config.setdefault("RESTX_MASK_SWAGGER", False)
def test_code_based_extension_get_returns_service_data(app: Flask, monkeypatch: pytest.MonkeyPatch):
service_result = {"entrypoint": "main:agent"}
service_mock = MagicMock(return_value=service_result)
monkeypatch.setattr(
"controllers.console.extension.CodeBasedExtensionService.get_code_based_extension",
service_mock,
)
with app.test_request_context(
"/console/api/code-based-extension",
method="GET",
query_string={"module": "workflow.tools"},
):
response = CodeBasedExtensionAPI().get()
assert response == {"module": "workflow.tools", "data": service_result}
service_mock.assert_called_once_with("workflow.tools")
def test_api_based_extension_get_returns_tenant_extensions(app: Flask, monkeypatch: pytest.MonkeyPatch):
extension = _make_extension(name="Weather API", api_key="abcdefghi123")
service_mock = MagicMock(return_value=[extension])
monkeypatch.setattr(
"controllers.console.extension.APIBasedExtensionService.get_all_by_tenant_id",
service_mock,
)
with app.test_request_context("/console/api/api-based-extension", method="GET"):
response = APIBasedExtensionAPI().get()
assert response[0]["id"] == extension.id
assert response[0]["name"] == "Weather API"
assert response[0]["api_endpoint"] == extension.api_endpoint
assert response[0]["api_key"].startswith(extension.api_key[:3])
service_mock.assert_called_once_with("tenant-123")
def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
saved_extension = _make_extension(name="Docs API", api_key="saved-secret")
save_mock = MagicMock(return_value=saved_extension)
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
payload = {
"name": "Docs API",
"api_endpoint": "https://docs.example.com/hook",
"api_key": "plain-secret",
}
with app.test_request_context("/console/api/api-based-extension", method="POST", json=payload):
response = APIBasedExtensionAPI().post()
args, _ = save_mock.call_args
created_extension: APIBasedExtension = args[0]
assert created_extension.tenant_id == "tenant-123"
assert created_extension.name == payload["name"]
assert created_extension.api_endpoint == payload["api_endpoint"]
assert created_extension.api_key == payload["api_key"]
assert response["name"] == saved_extension.name
save_mock.assert_called_once()
def test_api_based_extension_detail_get_fetches_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
extension = _make_extension(name="Docs API", api_key="abcdefg12345")
service_mock = MagicMock(return_value=extension)
monkeypatch.setattr(
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
service_mock,
)
extension_id = uuid.uuid4()
with app.test_request_context(f"/console/api/api-based-extension/{extension_id}", method="GET"):
response = APIBasedExtensionDetailAPI().get(extension_id)
assert response["id"] == extension.id
assert response["name"] == extension.name
service_mock.assert_called_once_with("tenant-123", str(extension_id))
def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkeypatch: pytest.MonkeyPatch):
existing_extension = _make_extension(name="Docs API", api_key="keep-me")
get_mock = MagicMock(return_value=existing_extension)
save_mock = MagicMock(return_value=existing_extension)
monkeypatch.setattr(
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
get_mock,
)
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
payload = {
"name": "Docs API Updated",
"api_endpoint": "https://docs.example.com/v2",
"api_key": HIDDEN_VALUE,
}
extension_id = uuid.uuid4()
with app.test_request_context(
f"/console/api/api-based-extension/{extension_id}",
method="POST",
json=payload,
):
response = APIBasedExtensionDetailAPI().post(extension_id)
assert existing_extension.name == payload["name"]
assert existing_extension.api_endpoint == payload["api_endpoint"]
assert existing_extension.api_key == "keep-me"
save_mock.assert_called_once_with(existing_extension)
assert response["name"] == payload["name"]
def test_api_based_extension_detail_post_updates_api_key_when_provided(app: Flask, monkeypatch: pytest.MonkeyPatch):
existing_extension = _make_extension(name="Docs API", api_key="old-secret")
get_mock = MagicMock(return_value=existing_extension)
save_mock = MagicMock(return_value=existing_extension)
monkeypatch.setattr(
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
get_mock,
)
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
payload = {
"name": "Docs API Updated",
"api_endpoint": "https://docs.example.com/v2",
"api_key": "new-secret",
}
extension_id = uuid.uuid4()
with app.test_request_context(
f"/console/api/api-based-extension/{extension_id}",
method="POST",
json=payload,
):
response = APIBasedExtensionDetailAPI().post(extension_id)
assert existing_extension.api_key == "new-secret"
save_mock.assert_called_once_with(existing_extension)
assert response["name"] == payload["name"]
def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
existing_extension = _make_extension()
get_mock = MagicMock(return_value=existing_extension)
delete_mock = MagicMock()
monkeypatch.setattr(
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
get_mock,
)
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.delete", delete_mock)
extension_id = uuid.uuid4()
with app.test_request_context(
f"/console/api/api-based-extension/{extension_id}",
method="DELETE",
):
response, status = APIBasedExtensionDetailAPI().delete(extension_id)
delete_mock.assert_called_once_with(existing_extension)
assert response == {"result": "success"}
assert status == 204

View File

@ -0,0 +1,195 @@
"""Unit tests for controllers.web.forgot_password endpoints."""
from __future__ import annotations
import base64
import builtins
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask.views import MethodView
# Ensure flask_restx.api finds MethodView during import.
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
def _load_controller_module():
"""Import controllers.web.forgot_password using a stub package."""
import importlib
import importlib.util
import sys
from types import ModuleType
parent_module_name = "controllers.web"
module_name = f"{parent_module_name}.forgot_password"
if parent_module_name not in sys.modules:
from flask_restx import Namespace
stub = ModuleType(parent_module_name)
stub.__file__ = "controllers/web/__init__.py"
stub.__path__ = ["controllers/web"]
stub.__package__ = "controllers"
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
stub.web_ns = Namespace("web", description="Web API", path="/")
sys.modules[parent_module_name] = stub
return importlib.import_module(module_name)
forgot_password_module = _load_controller_module()
ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi
ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi
ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi
@pytest.fixture
def app() -> Flask:
"""Configure a minimal Flask app for request contexts."""
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture(autouse=True)
def _enable_web_endpoint_guards():
"""Stub enterprise and feature toggles used by route decorators."""
features = SimpleNamespace(enable_email_password_login=True)
with (
patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True),
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features),
):
yield
@pytest.fixture(autouse=True)
def _mock_controller_db():
"""Replace controller-level db reference with a simple stub."""
fake_db = SimpleNamespace(engine=MagicMock(name="engine"))
fake_wraps_db = SimpleNamespace(
session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True))))
)
with (
patch("controllers.web.forgot_password.db", fake_db),
patch("controllers.console.wraps.db", fake_wraps_db),
):
yield fake_db
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10")
def test_send_reset_email_success(
mock_extract_ip: MagicMock,
mock_is_ip_limit: MagicMock,
mock_session: MagicMock,
mock_send_email: MagicMock,
app: Flask,
):
"""POST /forgot-password returns token when email exists and limits allow."""
mock_account = MagicMock()
session_ctx = MagicMock()
mock_session.return_value.__enter__.return_value = session_ctx
session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account
with app.test_request_context(
"/forgot-password",
method="POST",
json={"email": "user@example.com"},
):
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "reset-token"}
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("203.0.113.10")
mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US")
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token"))
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False)
def test_check_token_success(
mock_is_rate_limited: MagicMock,
mock_get_data: MagicMock,
mock_revoke: MagicMock,
mock_generate: MagicMock,
mock_reset_limit: MagicMock,
app: Flask,
):
"""POST /forgot-password/validity validates the code and refreshes token."""
mock_get_data.return_value = {"email": "user@example.com", "code": "123456"}
with app.test_request_context(
"/forgot-password/validity",
method="POST",
json={"email": "user@example.com", "code": "123456", "token": "old-token"},
):
response = ForgotPasswordCheckApi().post()
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
mock_is_rate_limited.assert_called_once_with("user@example.com")
mock_get_data.assert_called_once_with("old-token")
mock_revoke.assert_called_once_with("old-token")
mock_generate.assert_called_once_with(
"user@example.com",
code="123456",
additional_data={"phase": "reset"},
)
mock_reset_limit.assert_called_once_with("user@example.com")
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
@patch("controllers.web.forgot_password.Session")
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_success(
mock_get_data: MagicMock,
mock_revoke_token: MagicMock,
mock_session: MagicMock,
mock_token_bytes: MagicMock,
mock_hash_password: MagicMock,
app: Flask,
):
"""POST /forgot-password/resets updates the stored password when token is valid."""
mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"}
account = MagicMock()
session_ctx = MagicMock()
mock_session.return_value.__enter__.return_value = session_ctx
session_ctx.execute.return_value.scalar_one_or_none.return_value = account
with app.test_request_context(
"/forgot-password/resets",
method="POST",
json={
"token": "reset-token",
"new_password": "StrongPass123!",
"password_confirm": "StrongPass123!",
},
):
response = ForgotPasswordResetApi().post()
assert response == {"result": "success"}
mock_get_data.assert_called_once_with("reset-token")
mock_revoke_token.assert_called_once_with("reset-token")
mock_token_bytes.assert_called_once_with(16)
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
expected_password = base64.b64encode(b"hashed-value").decode()
assert account.password == expected_password
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
assert account.password_salt == expected_salt
session_ctx.commit.assert_called_once()

View File

@ -287,7 +287,7 @@ def test_validate_inputs_optional_file_with_empty_string():
def test_validate_inputs_optional_file_list_with_empty_list():
"""Test that optional FILE_LIST variable with empty list returns None"""
"""Test that optional FILE_LIST variable with empty list returns empty list (not None)"""
base_app_generator = BaseAppGenerator()
var_file_list = VariableEntity(
@ -302,6 +302,28 @@ def test_validate_inputs_optional_file_list_with_empty_list():
value=[],
)
# Empty list should be preserved, not converted to None
# This allows downstream components like document_extractor to handle empty lists properly
assert result == []
def test_validate_inputs_optional_file_list_with_empty_string():
"""Test that optional FILE_LIST variable with empty string returns None"""
base_app_generator = BaseAppGenerator()
var_file_list = VariableEntity(
variable="test_file_list",
label="test_file_list",
type=VariableEntityType.FILE_LIST,
required=False,
)
result = base_app_generator._validate_inputs(
variable_entity=var_file_list,
value="",
)
# Empty string should be treated as unset
assert result is None

View File

@ -96,7 +96,7 @@ class TestNotionExtractorAuthentication:
def test_init_with_integration_token_fallback(self, mock_get_token, mock_config, mock_document_model):
"""Test NotionExtractor falls back to integration token when credential not found."""
# Arrange
mock_get_token.return_value = None
mock_get_token.side_effect = Exception("No credential id found")
mock_config.NOTION_INTEGRATION_TOKEN = "integration-token-fallback"
# Act
@ -105,7 +105,7 @@ class TestNotionExtractorAuthentication:
notion_obj_id="page-456",
notion_page_type="page",
tenant_id="tenant-789",
credential_id="cred-123",
credential_id=None,
document_model=mock_document_model,
)
@ -117,7 +117,7 @@ class TestNotionExtractorAuthentication:
def test_init_missing_credentials_raises_error(self, mock_get_token, mock_config, mock_document_model):
"""Test NotionExtractor raises error when no credentials available."""
# Arrange
mock_get_token.return_value = None
mock_get_token.side_effect = Exception("No credential id found")
mock_config.NOTION_INTEGRATION_TOKEN = None
# Act & Assert
@ -127,7 +127,7 @@ class TestNotionExtractorAuthentication:
notion_obj_id="page-456",
notion_page_type="page",
tenant_id="tenant-789",
credential_id="cred-123",
credential_id=None,
document_model=mock_document_model,
)
assert "Must specify `integration_token`" in str(exc_info.value)

View File

@ -1,52 +1,109 @@
import secrets
from unittest.mock import MagicMock, patch
import pytest
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
_get_user_provided_host_header,
make_request,
)
@patch("httpx.Client.request")
def test_successful_request(mock_request):
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_successful_request(mock_get_client):
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_request.return_value = mock_response
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
assert response.status_code == 200
@patch("httpx.Client.request")
def test_retry_exceed_max_retries(mock_request):
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_retry_exceed_max_retries(mock_get_client):
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 500
side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES
mock_request.side_effect = side_effects
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
with pytest.raises(Exception) as e:
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
@patch("httpx.Client.request")
def test_retry_logic_success(mock_request):
side_effects = []
class TestGetUserProvidedHostHeader:
"""Tests for _get_user_provided_host_header function."""
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
status_code = secrets.choice(STATUS_FORCELIST)
mock_response = MagicMock()
mock_response.status_code = status_code
side_effects.append(mock_response)
def test_returns_none_when_headers_is_none(self):
assert _get_user_provided_host_header(None) is None
mock_response_200 = MagicMock()
mock_response_200.status_code = 200
side_effects.append(mock_response_200)
def test_returns_none_when_headers_is_empty(self):
assert _get_user_provided_host_header({}) is None
mock_request.side_effect = side_effects
def test_returns_none_when_host_header_not_present(self):
headers = {"Content-Type": "application/json", "Authorization": "Bearer token"}
assert _get_user_provided_host_header(headers) is None
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
def test_returns_host_header_lowercase(self):
headers = {"host": "example.com"}
assert _get_user_provided_host_header(headers) == "example.com"
def test_returns_host_header_uppercase(self):
headers = {"HOST": "example.com"}
assert _get_user_provided_host_header(headers) == "example.com"
def test_returns_host_header_mixed_case(self):
headers = {"HoSt": "example.com"}
assert _get_user_provided_host_header(headers) == "example.com"
def test_returns_host_header_from_multiple_headers(self):
headers = {"Content-Type": "application/json", "Host": "api.example.com", "Authorization": "Bearer token"}
assert _get_user_provided_host_header(headers) == "api.example.com"
def test_returns_first_host_header_when_duplicates(self):
headers = {"host": "first.com", "Host": "second.com"}
# Should return the first one encountered (iteration order is preserved in dict)
result = _get_user_provided_host_header(headers)
assert result in ("first.com", "second.com")
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_without_user_header(mock_get_client):
"""Test that when no Host header is provided, the default behavior is maintained."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
response = make_request("GET", "http://example.com")
assert response.status_code == 200
# Host should not be set if not provided by user
assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None
@patch("core.helper.ssrf_proxy._get_ssrf_client")
def test_host_header_preservation_with_user_header(mock_get_client):
"""Test that user-provided Host header is preserved in the request."""
mock_client = MagicMock()
mock_request = MagicMock()
mock_request.headers = {}
mock_response = MagicMock()
mock_response.status_code = 200
mock_client.send.return_value = mock_response
mock_client.request.return_value = mock_response
mock_get_client.return_value = mock_client
custom_host = "custom.example.com:8080"
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
assert response.status_code == 200
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
assert mock_request.call_args_list[0][1].get("method") == "GET"
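These tests pin down a case-insensitive lookup of a user-supplied `Host` header and assert it is preserved on the outgoing request. A sketch of a lookup consistent with the tests (the real helper in `core.helper.ssrf_proxy` may differ in detail):

```python
def get_user_provided_host_header(headers: dict[str, str] | None) -> str | None:
    """Return the user-supplied Host header value, matching the name case-insensitively."""
    if not headers:
        return None
    for name, value in headers.items():
        if name.lower() == "host":
            return value
    return None


assert get_user_provided_host_header(None) is None
assert get_user_provided_host_header({"HoSt": "example.com"}) == "example.com"
assert get_user_provided_host_header({"Content-Type": "application/json"}) is None
```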

View File

@ -1,6 +1,9 @@
import re
from datetime import datetime
import pytest
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path
class TestValidateUrl:
@ -136,3 +139,51 @@ class TestValidateProjectName:
"""Test custom default name"""
result = validate_project_name("", "Custom Default")
assert result == "Custom Default"
class TestGenerateDottedOrder:
"""Test cases for generate_dotted_order function"""
def test_dotted_order_has_6_digit_microseconds(self):
"""Test that timestamp includes full 6-digit microseconds for LangSmith API compatibility.
LangSmith API expects timestamps in format: YYYYMMDDTHHMMSSffffffZ (6-digit microseconds).
Previously, the code truncated to 3 digits which caused API errors:
'cannot parse .111 as .000000'
"""
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
run_id = "test-run-id"
result = generate_dotted_order(run_id, start_time)
# Extract timestamp portion (before the run_id)
timestamp_match = re.match(r"^(\d{8}T\d{6})(\d+)Z", result)
assert timestamp_match is not None, "Timestamp format should match YYYYMMDDTHHMMSSffffffZ"
microseconds = timestamp_match.group(2)
assert len(microseconds) == 6, f"Microseconds should be 6 digits, got {len(microseconds)}: {microseconds}"
def test_dotted_order_format_matches_langsmith_expected(self):
"""Test that dotted_order format matches LangSmith API expected format."""
start_time = datetime(2025, 1, 15, 10, 30, 45, 123456)
run_id = "abc123"
result = generate_dotted_order(run_id, start_time)
# LangSmith expects: YYYYMMDDTHHMMSSffffffZ followed by run_id
assert result == "20250115T103045123456Zabc123"
def test_dotted_order_with_parent(self):
"""Test dotted_order generation with parent order uses dot separator."""
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
run_id = "child-run-id"
parent_order = "20251223T041955000000Zparent-run-id"
result = generate_dotted_order(run_id, start_time, parent_order)
assert result == "20251223T041955000000Zparent-run-id.20251223T041955111000Zchild-run-id"
def test_dotted_order_without_parent_has_no_dot(self):
"""Test dotted_order generation without parent has no dot separator."""
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
run_id = "test-run-id"
result = generate_dotted_order(run_id, start_time, None)
assert "." not in result

View File

@ -1,5 +1,7 @@
import os
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
@ -25,3 +27,35 @@ def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
assert job_id is not None
assert isinstance(job_id, str)
def test_build_url_normalizes_slashes_for_crawl(mocker: MockerFixture):
api_key = "fc-"
base_urls = ["https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"]
for base in base_urls:
app = FirecrawlApp(api_key=api_key, base_url=base)
mock_post = mocker.patch("httpx.post")
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {"id": "job123"}
mock_post.return_value = mock_resp
app.crawl_url("https://example.com", params=None)
called_url = mock_post.call_args[0][0]
assert called_url == "https://custom.firecrawl.dev/v2/crawl"
def test_error_handler_handles_non_json_error_bodies(mocker: MockerFixture):
api_key = "fc-"
app = FirecrawlApp(api_key=api_key, base_url="https://custom.firecrawl.dev/")
mock_post = mocker.patch("httpx.post")
mock_resp = MagicMock()
mock_resp.status_code = 404
mock_resp.text = "Not Found"
mock_resp.json.side_effect = Exception("Not JSON")
mock_post.return_value = mock_resp
with pytest.raises(Exception) as excinfo:
app.scrape_url("https://example.com")
# Should not raise a JSONDecodeError; current behavior reports status code only
assert str(excinfo.value) == "Failed to scrape URL. Status code: 404"

View File

@ -1,3 +1,4 @@
import json
from unittest.mock import MagicMock, patch
import httpx
@ -110,9 +111,11 @@ class TestFirecrawlAuth:
@pytest.mark.parametrize(
("status_code", "response_text", "has_json_error", "expected_error_contains"),
[
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
(401, "Not JSON", True, "Expecting value"), # JSON decode error
(403, '{"error": "Forbidden"}', False, "Failed to authorize. Status code: 403. Error: Forbidden"),
# empty body falls back to generic message
(404, "", True, "Failed to authorize. Status code: 404. Error: Unknown error occurred"),
# non-JSON body is surfaced directly
(401, "Not JSON", True, "Failed to authorize. Status code: 401. Error: Not JSON"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@ -124,12 +127,14 @@ class TestFirecrawlAuth:
mock_response.status_code = status_code
mock_response.text = response_text
if has_json_error:
mock_response.json.side_effect = Exception("Not JSON")
mock_response.json.side_effect = json.JSONDecodeError("Not JSON", "", 0)
else:
mock_response.json.return_value = {"error": "Forbidden"}
mock_post.return_value = mock_response
with pytest.raises(Exception) as exc_info:
auth_instance.validate_credentials()
assert expected_error_contains in str(exc_info.value)
assert str(exc_info.value) == expected_error_contains
@pytest.mark.parametrize(
("exception_type", "exception_message"),
@ -164,20 +169,21 @@ class TestFirecrawlAuth:
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_use_custom_base_url_in_validation(self, mock_post):
"""Test that custom base URL is used in validation"""
"""Test that custom base URL is used in validation and normalized"""
mock_response = MagicMock()
mock_response.status_code = 200
mock_post.return_value = mock_response
credentials = {
"auth_type": "bearer",
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
}
auth = FirecrawlAuth(credentials)
result = auth.validate_credentials()
for base in ("https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"):
credentials = {
"auth_type": "bearer",
"config": {"api_key": "test_api_key_123", "base_url": base},
}
auth = FirecrawlAuth(credentials)
result = auth.validate_credentials()
assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
@patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):

View File

@ -0,0 +1,71 @@
from unittest.mock import MagicMock
import httpx
from models import Account
from services import app_dsl_service
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
def _build_response(url: str, status_code: int, content: bytes = b"") -> httpx.Response:
request = httpx.Request("GET", url)
return httpx.Response(status_code=status_code, request=request, content=content)
def _pending_yaml_content(version: str = "99.0.0") -> bytes:
return (f'version: "{version}"\nkind: app\napp:\n name: Loop Test\n mode: workflow\n').encode()
def _account_mock() -> MagicMock:
account = MagicMock(spec=Account)
account.current_tenant_id = "tenant-1"
return account
def test_import_app_yaml_url_user_attachments_keeps_original_url(monkeypatch):
yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml"
raw_url = "https://raw.githubusercontent.com/user-attachments/files/24290802/loop-test.yml"
yaml_bytes = _pending_yaml_content()
def fake_get(url: str, **kwargs):
if url == raw_url:
return _build_response(url, status_code=404)
assert url == yaml_url
return _build_response(url, status_code=200, content=yaml_bytes)
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
service = AppDslService(MagicMock())
result = service.import_app(
account=_account_mock(),
import_mode=ImportMode.YAML_URL,
yaml_url=yaml_url,
)
assert result.status == ImportStatus.PENDING
assert result.imported_dsl_version == "99.0.0"
def test_import_app_yaml_url_github_blob_rewrites_to_raw(monkeypatch):
yaml_url = "https://github.com/acme/repo/blob/main/app.yml"
raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml"
yaml_bytes = _pending_yaml_content()
requested_urls: list[str] = []
def fake_get(url: str, **kwargs):
requested_urls.append(url)
assert url == raw_url
return _build_response(url, status_code=200, content=yaml_bytes)
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
service = AppDslService(MagicMock())
result = service.import_app(
account=_account_mock(),
import_mode=ImportMode.YAML_URL,
yaml_url=yaml_url,
)
assert result.status == ImportStatus.PENDING
assert requested_urls == [raw_url]

View File

@ -1636,7 +1636,7 @@ requires-dist = [
{ name = "pydantic-extra-types", specifier = "~=2.10.3" },
{ name = "pydantic-settings", specifier = "~=2.11.0" },
{ name = "pyjwt", specifier = "~=2.10.1" },
{ name = "pypdfium2", specifier = "==4.30.0" },
{ name = "pypdfium2", specifier = "==5.2.0" },
{ name = "python-docx", specifier = "~=1.1.0" },
{ name = "python-dotenv", specifier = "==1.0.1" },
{ name = "pyyaml", specifier = "~=6.0.1" },
@ -4993,22 +4993,31 @@ wheels = [
[[package]]
name = "pypdfium2"
version = "4.30.0"
version = "5.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a1/14/838b3ba247a0ba92e4df5d23f2bea9478edcfd72b78a39d6ca36ccd84ad2/pypdfium2-4.30.0.tar.gz", hash = "sha256:48b5b7e5566665bc1015b9d69c1ebabe21f6aee468b509531c3c8318eeee2e16", size = 140239, upload-time = "2024-05-09T18:33:17.552Z" }
sdist = { url = "https://files.pythonhosted.org/packages/f6/ab/73c7d24e4eac9ba952569403b32b7cca9412fc5b9bef54fdbd669551389f/pypdfium2-5.2.0.tar.gz", hash = "sha256:43863625231ce999c1ebbed6721a88de818b2ab4d909c1de558d413b9a400256", size = 269999, upload-time = "2025-12-12T13:20:15.353Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/9a/c8ff5cc352c1b60b0b97642ae734f51edbab6e28b45b4fcdfe5306ee3c83/pypdfium2-4.30.0-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:b33ceded0b6ff5b2b93bc1fe0ad4b71aa6b7e7bd5875f1ca0cdfb6ba6ac01aab", size = 2837254, upload-time = "2024-05-09T18:32:48.653Z" },
{ url = "https://files.pythonhosted.org/packages/21/8b/27d4d5409f3c76b985f4ee4afe147b606594411e15ac4dc1c3363c9a9810/pypdfium2-4.30.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:4e55689f4b06e2d2406203e771f78789bd4f190731b5d57383d05cf611d829de", size = 2707624, upload-time = "2024-05-09T18:32:51.458Z" },
{ url = "https://files.pythonhosted.org/packages/11/63/28a73ca17c24b41a205d658e177d68e198d7dde65a8c99c821d231b6ee3d/pypdfium2-4.30.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e6e50f5ce7f65a40a33d7c9edc39f23140c57e37144c2d6d9e9262a2a854854", size = 2793126, upload-time = "2024-05-09T18:32:53.581Z" },
{ url = "https://files.pythonhosted.org/packages/d1/96/53b3ebf0955edbd02ac6da16a818ecc65c939e98fdeb4e0958362bd385c8/pypdfium2-4.30.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3d0dd3ecaffd0b6dbda3da663220e705cb563918249bda26058c6036752ba3a2", size = 2591077, upload-time = "2024-05-09T18:32:55.99Z" },
{ url = "https://files.pythonhosted.org/packages/ec/ee/0394e56e7cab8b5b21f744d988400948ef71a9a892cbeb0b200d324ab2c7/pypdfium2-4.30.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cc3bf29b0db8c76cdfaac1ec1cde8edf211a7de7390fbf8934ad2aa9b4d6dfad", size = 2864431, upload-time = "2024-05-09T18:32:57.911Z" },
{ url = "https://files.pythonhosted.org/packages/65/cd/3f1edf20a0ef4a212a5e20a5900e64942c5a374473671ac0780eaa08ea80/pypdfium2-4.30.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1f78d2189e0ddf9ac2b7a9b9bd4f0c66f54d1389ff6c17e9fd9dc034d06eb3f", size = 2812008, upload-time = "2024-05-09T18:32:59.886Z" },
{ url = "https://files.pythonhosted.org/packages/c8/91/2d517db61845698f41a2a974de90762e50faeb529201c6b3574935969045/pypdfium2-4.30.0-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:5eda3641a2da7a7a0b2f4dbd71d706401a656fea521b6b6faa0675b15d31a163", size = 6181543, upload-time = "2024-05-09T18:33:02.597Z" },
{ url = "https://files.pythonhosted.org/packages/ba/c4/ed1315143a7a84b2c7616569dfb472473968d628f17c231c39e29ae9d780/pypdfium2-4.30.0-py3-none-musllinux_1_1_i686.whl", hash = "sha256:0dfa61421b5eb68e1188b0b2231e7ba35735aef2d867d86e48ee6cab6975195e", size = 6175911, upload-time = "2024-05-09T18:33:05.376Z" },
{ url = "https://files.pythonhosted.org/packages/7a/c4/9e62d03f414e0e3051c56d5943c3bf42aa9608ede4e19dc96438364e9e03/pypdfium2-4.30.0-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:f33bd79e7a09d5f7acca3b0b69ff6c8a488869a7fab48fdf400fec6e20b9c8be", size = 6267430, upload-time = "2024-05-09T18:33:08.067Z" },
{ url = "https://files.pythonhosted.org/packages/90/47/eda4904f715fb98561e34012826e883816945934a851745570521ec89520/pypdfium2-4.30.0-py3-none-win32.whl", hash = "sha256:ee2410f15d576d976c2ab2558c93d392a25fb9f6635e8dd0a8a3a5241b275e0e", size = 2775951, upload-time = "2024-05-09T18:33:10.567Z" },
{ url = "https://files.pythonhosted.org/packages/25/bd/56d9ec6b9f0fc4e0d95288759f3179f0fcd34b1a1526b75673d2f6d5196f/pypdfium2-4.30.0-py3-none-win_amd64.whl", hash = "sha256:90dbb2ac07be53219f56be09961eb95cf2473f834d01a42d901d13ccfad64b4c", size = 2892098, upload-time = "2024-05-09T18:33:13.107Z" },
{ url = "https://files.pythonhosted.org/packages/be/7a/097801205b991bc3115e8af1edb850d30aeaf0118520b016354cf5ccd3f6/pypdfium2-4.30.0-py3-none-win_arm64.whl", hash = "sha256:119b2969a6d6b1e8d55e99caaf05290294f2d0fe49c12a3f17102d01c441bd29", size = 2752118, upload-time = "2024-05-09T18:33:15.489Z" },
{ url = "https://files.pythonhosted.org/packages/fb/0c/9108ae5266ee4cdf495f99205c44d4b5c83b4eb227c2b610d35c9e9fe961/pypdfium2-5.2.0-py3-none-android_23_arm64_v8a.whl", hash = "sha256:1ba4187a45ce4cf08f2a8c7e0f8970c36b9aa1770c8a3412a70781c1d80fb145", size = 2763268, upload-time = "2025-12-12T13:19:37.354Z" },
{ url = "https://files.pythonhosted.org/packages/35/8c/55f5c8a2c6b293f5c020be4aa123eaa891e797c514e5eccd8cb042740d37/pypdfium2-5.2.0-py3-none-android_23_armeabi_v7a.whl", hash = "sha256:80c55e10a8c9242f0901d35a9a306dd09accce8e497507bb23fcec017d45fe2e", size = 2301821, upload-time = "2025-12-12T13:19:39.484Z" },
{ url = "https://files.pythonhosted.org/packages/5e/7d/efa013e3795b41c59dd1e472f7201c241232c3a6553be4917e3a26b9f225/pypdfium2-5.2.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:73523ae69cd95c084c1342096893b2143ea73c36fdde35494780ba431e6a7d6e", size = 2816428, upload-time = "2025-12-12T13:19:41.735Z" },
{ url = "https://files.pythonhosted.org/packages/ec/ae/8c30af6ff2ab41a7cb84753ee79dd1e0a8932c9bda9fe19759d69cbbf115/pypdfium2-5.2.0-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:19c501d22ef5eb98e42416d22cc3ac66d4808b436e3d06686392f24d8d9f708d", size = 2939486, upload-time = "2025-12-12T13:19:43.176Z" },
{ url = "https://files.pythonhosted.org/packages/64/64/454a73c49a04c2c290917ad86184e4da959e9e5aba94b3b046328c89be93/pypdfium2-5.2.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ed15a3f58d6ee4905f0d0a731e30b381b457c30689512589c7f57950b0cdcec", size = 2979235, upload-time = "2025-12-12T13:19:44.635Z" },
{ url = "https://files.pythonhosted.org/packages/4e/29/f1cab8e31192dd367dc7b1afa71f45cfcb8ff0b176f1d2a0f528faf04052/pypdfium2-5.2.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:329cd1e9f068e8729e0d0b79a070d6126f52bc48ff1e40505cb207a5e20ce0ba", size = 2763001, upload-time = "2025-12-12T13:19:47.598Z" },
{ url = "https://files.pythonhosted.org/packages/bc/5d/e95fad8fdac960854173469c4b6931d5de5e09d05e6ee7d9756f8b95eef0/pypdfium2-5.2.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:325259759886e66619504df4721fef3b8deabf8a233e4f4a66e0c32ebae60c2f", size = 3057024, upload-time = "2025-12-12T13:19:49.179Z" },
{ url = "https://files.pythonhosted.org/packages/f4/32/468591d017ab67f8142d40f4db8163b6d8bb404fe0d22da75a5c661dc144/pypdfium2-5.2.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5683e8f08ab38ed05e0e59e611451ec74332803d4e78f8c45658ea1d372a17af", size = 3448598, upload-time = "2025-12-12T13:19:50.979Z" },
{ url = "https://files.pythonhosted.org/packages/f9/a5/57b4e389b77ab5f7e9361dc7fc03b5378e678ba81b21e791e85350fbb235/pypdfium2-5.2.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da4815426a5adcf03bf4d2c5f26c0ff8109dbfaf2c3415984689931bc6006ef9", size = 2993946, upload-time = "2025-12-12T13:19:53.154Z" },
{ url = "https://files.pythonhosted.org/packages/84/3a/e03e9978f817632aa56183bb7a4989284086fdd45de3245ead35f147179b/pypdfium2-5.2.0-py3-none-manylinux_2_27_s390x.manylinux_2_28_s390x.whl", hash = "sha256:64bf5c039b2c314dab1fd158bfff99db96299a5b5c6d96fc056071166056f1de", size = 3673148, upload-time = "2025-12-12T13:19:54.528Z" },
{ url = "https://files.pythonhosted.org/packages/13/ee/e581506806553afa4b7939d47bf50dca35c1151b8cc960f4542a6eb135ce/pypdfium2-5.2.0-py3-none-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:76b42a17748ac7dc04d5ef04d0561c6a0a4b546d113ec1d101d59650c6a340f7", size = 2964757, upload-time = "2025-12-12T13:19:56.406Z" },
{ url = "https://files.pythonhosted.org/packages/00/be/3715c652aff30f12284523dd337843d0efe3e721020f0ec303a99ffffd8d/pypdfium2-5.2.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:9d4367d471439fae846f0aba91ff9e8d66e524edcf3c8d6e02fe96fa306e13b9", size = 4130319, upload-time = "2025-12-12T13:19:57.889Z" },
{ url = "https://files.pythonhosted.org/packages/b0/0b/28aa2ede9004dd4192266bbad394df0896787f7c7bcfa4d1a6e091ad9a2c/pypdfium2-5.2.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:613f6bb2b47d76b66c0bf2ca581c7c33e3dd9dcb29d65d8c34fef4135f933149", size = 3746488, upload-time = "2025-12-12T13:19:59.469Z" },
{ url = "https://files.pythonhosted.org/packages/bc/04/1b791e1219652bbfc51df6498267d8dcec73ad508b99388b2890902ccd9d/pypdfium2-5.2.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c03fad3f2fa68d358f5dd4deb07e438482fa26fae439c49d127576d969769ca1", size = 4336534, upload-time = "2025-12-12T13:20:01.28Z" },
{ url = "https://files.pythonhosted.org/packages/4f/e3/6f00f963bb702ffd2e3e2d9c7286bc3bb0bebcdfa96ca897d466f66976c6/pypdfium2-5.2.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:f10be1900ae21879d02d9f4d58c2d2db3a2e6da611736a8e9decc22d1fb02909", size = 4375079, upload-time = "2025-12-12T13:20:03.117Z" },
{ url = "https://files.pythonhosted.org/packages/3a/2a/7ec2b191b5e1b7716a0dfc14e6860e89bb355fb3b94ed0c1d46db526858c/pypdfium2-5.2.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:97c1a126d30378726872f94866e38c055740cae80313638dafd1cd448d05e7c0", size = 3928648, upload-time = "2025-12-12T13:20:05.041Z" },
{ url = "https://files.pythonhosted.org/packages/bf/c3/c6d972fa095ff3ace76f9d3a91ceaf8a9dbbe0d9a5a84ac1d6178a46630e/pypdfium2-5.2.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:c369f183a90781b788af9a357a877bc8caddc24801e8346d0bf23f3295f89f3a", size = 4997772, upload-time = "2025-12-12T13:20:06.453Z" },
{ url = "https://files.pythonhosted.org/packages/22/45/2c64584b7a3ca5c4652280a884f4b85b8ed24e27662adeebdc06d991c917/pypdfium2-5.2.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b391f1cceb454934b612a05b54e90f98aafeffe5e73830d71700b17f0812226b", size = 4180046, upload-time = "2025-12-12T13:20:08.715Z" },
{ url = "https://files.pythonhosted.org/packages/d6/99/8d1ff87b626649400e62a2840e6e10fe258443ba518798e071fee4cd86f9/pypdfium2-5.2.0-py3-none-win32.whl", hash = "sha256:c68067938f617c37e4d17b18de7cac231fc7ce0eb7b6653b7283ebe8764d4999", size = 2990175, upload-time = "2025-12-12T13:20:10.241Z" },
{ url = "https://files.pythonhosted.org/packages/93/fc/114fff8895b620aac4984808e93d01b6d7b93e342a1635fcfe2a5f39cf39/pypdfium2-5.2.0-py3-none-win_amd64.whl", hash = "sha256:eb0591b720e8aaeab9475c66d653655ec1be0464b946f3f48a53922e843f0f3b", size = 3098615, upload-time = "2025-12-12T13:20:11.795Z" },
{ url = "https://files.pythonhosted.org/packages/08/97/eb738bff5998760d6e0cbcb7dd04cbf1a95a97b997fac6d4e57562a58992/pypdfium2-5.2.0-py3-none-win_arm64.whl", hash = "sha256:5dd1ef579f19fa3719aee4959b28bda44b1072405756708b5e83df8806a19521", size = 2939479, upload-time = "2025-12-12T13:20:13.815Z" },
]
[[package]]

View File

@ -468,6 +468,7 @@ ALIYUN_OSS_REGION=ap-southeast-1
ALIYUN_OSS_AUTH_VERSION=v4
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
# Tencent COS Configuration
#
@ -491,6 +492,7 @@ HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url
HUAWEI_OBS_PATH_STYLE=false
# Volcengine TOS Configuration
#

View File

@ -23,6 +23,10 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
- Navigate to the `docker` directory.
- Copy the `.env.example` file to a new file named `.env` by running `cp .env.example .env`.
- Customize the `.env` file as needed. Refer to the `.env.example` file for detailed configuration options.
- **Optional (Recommended for upgrades)**:
You may use the environment synchronization tool to help keep your `.env` file aligned with the latest `.env.example` updates, while preserving your custom settings.
This is especially useful when upgrading Dify or managing a large, customized `.env` file.
See the [Environment Variables Synchronization](#environment-variables-synchronization) section below.
1. **Running the Services**:
- Execute `docker compose up` from the `docker` directory to start the services.
- To specify a vector database, set the `VECTOR_STORE` variable in your `.env` file to your desired vector database service, such as `milvus`, `weaviate`, or `opensearch`.
@ -111,6 +115,47 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w
- Each service like `nginx`, `redis`, `db`, and vector databases have specific environment variables that are directly referenced in the `docker-compose.yaml`.
### Environment Variables Synchronization
When upgrading Dify or pulling the latest changes, new environment variables may be introduced in `.env.example`.
To help keep your existing `.env` file up to date **without losing your custom values**, an optional environment variables synchronization tool is provided.
> This tool performs a **one-way synchronization** from `.env.example` to `.env`.
> Existing values in `.env` are never overwritten automatically.
#### `dify-env-sync.sh` (Optional)
This script compares your current `.env` file with the latest `.env.example` template and helps you safely apply new or updated environment variables.
**What it does**
- Creates a backup of the current `.env` file before making any changes
- Synchronizes newly added environment variables from `.env.example`
- Preserves all existing custom values in `.env`
- Displays differences and variables removed from `.env.example` for review
**Backup behavior**
Before synchronization, the current `.env` file is saved to the `env-backup/` directory with a timestamped filename
(e.g. `env-backup/.env.backup_20231218_143022`).
**When to use**
- After upgrading Dify to a newer version
- When `.env.example` has been updated with new environment variables
- When managing a large or heavily customized `.env` file
**Usage**
```bash
# Grant execution permission (first time only)
chmod +x dify-env-sync.sh
# Run the synchronization
./dify-env-sync.sh
```
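The run below is an illustrative sketch of a synchronization that finds no differences. The messages are taken from the script's own log output, but the exact lines (and the backup filename timestamp) depend on your `.env` contents; output is abbreviated.

```bash
./dify-env-sync.sh
# [INFO] === Dify Environment Variables Synchronization Script ===
# [SUCCESS] Required files verified
# [SUCCESS] Backed up existing .env to env-backup/.env.backup_20231218_143022
# [INFO] Detecting differences between .env and .env.example...
# [INFO] No differences detected
# [SUCCESS] No removed environment variables found
# [SUCCESS] Partial synchronization of .env file completed
# [SUCCESS] === Synchronization process completed successfully ===
```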
### Additional Information
- **Continuous Improvement Phase**: We are actively seeking feedback from the community to refine and enhance the deployment process. As more users adopt this new method, we will continue to make improvements based on your experiences and suggestions.

docker/dify-env-sync.sh (new executable file, 465 lines)
View File

@ -0,0 +1,465 @@
#!/bin/bash
# ================================================================
# Dify Environment Variables Synchronization Script
#
# Features:
# - Synchronize latest settings from .env.example to .env
# - Preserve custom settings in existing .env
# - Add new environment variables
# - Detect removed environment variables
# - Create backup files
# ================================================================
set -eo pipefail # Exit on error and pipe failures (safer for complex variable handling)
# Error handling function
# Arguments:
# $1 - Line number where error occurred
# $2 - Error code
handle_error() {
local line_no=$1
local error_code=$2
echo -e "\033[0;31m[ERROR]\033[0m Script error: line $line_no with error code $error_code" >&2
echo -e "\033[0;31m[ERROR]\033[0m Debug info: current working directory $(pwd)" >&2
exit $error_code
}
# Set error trap
trap 'handle_error ${LINENO} $?' ERR
# Color settings for output
readonly RED='\033[0;31m'
readonly GREEN='\033[0;32m'
readonly YELLOW='\033[1;33m'
readonly BLUE='\033[0;34m'
readonly NC='\033[0m' # No Color
# Logging functions
# Print informational message in blue
# Arguments: $1 - Message to print
log_info() {
echo -e "${BLUE}[INFO]${NC} $1"
}
# Print success message in green
# Arguments: $1 - Message to print
log_success() {
echo -e "${GREEN}[SUCCESS]${NC} $1"
}
# Print warning message in yellow
# Arguments: $1 - Message to print
log_warning() {
echo -e "${YELLOW}[WARNING]${NC} $1" >&2
}
# Print error message in red to stderr
# Arguments: $1 - Message to print
log_error() {
echo -e "${RED}[ERROR]${NC} $1" >&2
}
# Check for required files and create .env if missing
# Verifies that .env.example exists and creates .env from template if needed
check_files() {
log_info "Checking required files..."
if [[ ! -f ".env.example" ]]; then
log_error ".env.example file not found"
exit 1
fi
if [[ ! -f ".env" ]]; then
log_warning ".env file does not exist. Creating from .env.example."
cp ".env.example" ".env"
log_success ".env file created"
fi
log_success "Required files verified"
}
# Create timestamped backup of .env file
# Creates env-backup directory if needed and backs up current .env file
create_backup() {
local timestamp=$(date +"%Y%m%d_%H%M%S")
local backup_dir="env-backup"
# Create backup directory if it doesn't exist
if [[ ! -d "$backup_dir" ]]; then
mkdir -p "$backup_dir"
log_info "Created backup directory: $backup_dir"
fi
if [[ -f ".env" ]]; then
local backup_file="${backup_dir}/.env.backup_${timestamp}"
cp ".env" "$backup_file"
log_success "Backed up existing .env to $backup_file"
fi
}
# Detect differences between .env and .env.example (optimized for large files)
detect_differences() {
log_info "Detecting differences between .env and .env.example..."
# Create secure temporary directory
local temp_dir=$(mktemp -d)
local temp_diff="$temp_dir/env_diff"
# Store diff file path as global variable
declare -g DIFF_FILE="$temp_diff"
declare -g TEMP_DIR="$temp_dir"
# Initialize difference file
> "$temp_diff"
# Use awk for efficient comparison (much faster for large files)
local diff_count=$(awk -F= '
BEGIN { OFS="\x01" }
FNR==NR {
if (!/^[[:space:]]*#/ && !/^[[:space:]]*$/ && /=/) {
gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1)
key = $1
value = substr($0, index($0,"=")+1)
gsub(/^[[:space:]]+|[[:space:]]+$/, "", value)
env_values[key] = value
}
next
}
{
if (!/^[[:space:]]*#/ && !/^[[:space:]]*$/ && /=/) {
gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1)
key = $1
example_value = substr($0, index($0,"=")+1)
gsub(/^[[:space:]]+|[[:space:]]+$/, "", example_value)
if (key in env_values && env_values[key] != example_value) {
print key, env_values[key], example_value > "'$temp_diff'"
diff_count++
}
}
}
END { print diff_count }
' .env .env.example)
if [[ $diff_count -gt 0 ]]; then
log_success "Detected differences in $diff_count environment variables"
# Show detailed differences
show_differences_detail
else
log_info "No differences detected"
fi
}
# Parse environment variable line
# Extracts key-value pairs from .env file format lines
# Arguments:
# $1 - Line to parse
# Returns:
# 0 - Success, outputs "key|value" format
# 1 - Skip (empty line, comment, or invalid format)
parse_env_line() {
local line="$1"
local key=""
local value=""
# Skip empty lines or comment lines
[[ -z "$line" || "$line" =~ ^[[:space:]]*# ]] && return 1
# Split by =
if [[ "$line" =~ ^([^=]+)=(.*)$ ]]; then
key="${BASH_REMATCH[1]}"
value="${BASH_REMATCH[2]}"
# Remove leading and trailing whitespace
key=$(echo "$key" | sed 's/^[[:space:]]*//; s/[[:space:]]*$//')
value=$(echo "$value" | sed 's/^[[:space:]]*//; s/[[:space:]]*$//')
if [[ -n "$key" ]]; then
echo "$key|$value"
return 0
fi
fi
return 1
}
# Show detailed differences
show_differences_detail() {
log_info ""
log_info "=== Environment Variable Differences ==="
# Read differences from the already created diff file
if [[ ! -s "$DIFF_FILE" ]]; then
log_info "No differences to display"
return
fi
# Display differences
local count=1
while IFS=$'\x01' read -r key env_value example_value; do
echo ""
echo -e "${YELLOW}[$count] $key${NC}"
echo -e " ${GREEN}.env (current)${NC} : ${env_value}"
echo -e " ${BLUE}.env.example (recommended)${NC}: ${example_value}"
# Analyze value changes
analyze_value_change "$env_value" "$example_value"
((count++))
done < "$DIFF_FILE"
echo ""
log_info "=== Difference Analysis Complete ==="
log_info "Note: Consider changing to the recommended values above."
log_info "Current implementation preserves .env values."
echo ""
}
# Analyze value changes
analyze_value_change() {
local current_value="$1"
local recommended_value="$2"
# Analyze value characteristics
local analysis=""
# Empty value check
if [[ -z "$current_value" && -n "$recommended_value" ]]; then
analysis=" ${RED}→ Setting from empty to recommended value${NC}"
elif [[ -n "$current_value" && -z "$recommended_value" ]]; then
analysis=" ${RED}→ Recommended value changed to empty${NC}"
# Numeric check - using arithmetic evaluation for robust comparison
elif [[ "$current_value" =~ ^[0-9]+$ && "$recommended_value" =~ ^[0-9]+$ ]]; then
# Use arithmetic evaluation to handle leading zeros correctly
if (( 10#$current_value < 10#$recommended_value )); then
analysis=" ${BLUE}→ Numeric increase (${current_value} < ${recommended_value})${NC}"
elif (( 10#$current_value > 10#$recommended_value )); then
analysis=" ${YELLOW}→ Numeric decrease (${current_value} > ${recommended_value})${NC}"
fi
# Boolean check
elif [[ "$current_value" =~ ^(true|false)$ && "$recommended_value" =~ ^(true|false)$ ]]; then
if [[ "$current_value" != "$recommended_value" ]]; then
analysis=" ${BLUE}→ Boolean value change (${current_value} → ${recommended_value})${NC}"
fi
# URL/endpoint check
elif [[ "$current_value" =~ ^https?:// || "$recommended_value" =~ ^https?:// ]]; then
analysis=" ${BLUE}→ URL/endpoint change${NC}"
# File path check
elif [[ "$current_value" =~ ^/ || "$recommended_value" =~ ^/ ]]; then
analysis=" ${BLUE}→ File path change${NC}"
else
# Length comparison
local current_len=${#current_value}
local recommended_len=${#recommended_value}
if [[ $current_len -ne $recommended_len ]]; then
analysis=" ${YELLOW}→ String length change (${current_len} → ${recommended_len} characters)${NC}"
fi
fi
if [[ -n "$analysis" ]]; then
echo -e "$analysis"
fi
}
# Synchronize .env file with .env.example while preserving custom values
# Creates a new .env file based on .env.example structure, preserving existing custom values
# Global variables used: DIFF_FILE, TEMP_DIR
sync_env_file() {
log_info "Starting partial synchronization of .env file..."
local new_env_file=".env.new"
local preserved_count=0
local updated_count=0
# Pre-process diff file for efficient lookup
local lookup_file=""
if [[ -f "$DIFF_FILE" && -s "$DIFF_FILE" ]]; then
lookup_file="${DIFF_FILE}.lookup"
# Create sorted lookup file for fast search
sort "$DIFF_FILE" > "$lookup_file"
log_info "Created lookup file for $(wc -l < "$DIFF_FILE") preserved values"
fi
# Use AWK for efficient processing (much faster than bash loop for large files)
log_info "Processing $(wc -l < .env.example) lines with AWK..."
local preserved_keys_file="${TEMP_DIR}/preserved_keys"
local awk_preserved_count_file="${TEMP_DIR}/awk_preserved_count"
local awk_updated_count_file="${TEMP_DIR}/awk_updated_count"
awk -F'=' -v lookup_file="$lookup_file" -v preserved_file="$preserved_keys_file" \
-v preserved_count_file="$awk_preserved_count_file" -v updated_count_file="$awk_updated_count_file" '
BEGIN {
preserved_count = 0
updated_count = 0
# Load preserved values if lookup file exists
if (lookup_file != "") {
while ((getline line < lookup_file) > 0) {
split(line, parts, "\x01")
key = parts[1]
value = parts[2]
preserved_values[key] = value
}
close(lookup_file)
}
}
# Process each line
{
# Check if this is an environment variable line
if (/^[[:space:]]*[A-Za-z_][A-Za-z0-9_]*[[:space:]]*=/) {
# Extract key
key = $1
gsub(/^[[:space:]]+|[[:space:]]+$/, "", key)
# Check if key should be preserved
if (key in preserved_values) {
print key "=" preserved_values[key]
print key > preserved_file
preserved_count++
} else {
print $0
updated_count++
}
} else {
# Not an env var line, preserve as-is
print $0
}
}
END {
print preserved_count > preserved_count_file
print updated_count > updated_count_file
}
' .env.example > "$new_env_file"
# Read counters and preserved keys
if [[ -f "$awk_preserved_count_file" ]]; then
preserved_count=$(cat "$awk_preserved_count_file")
fi
if [[ -f "$awk_updated_count_file" ]]; then
updated_count=$(cat "$awk_updated_count_file")
fi
# Show what was preserved
if [[ -f "$preserved_keys_file" ]]; then
while read -r key; do
[[ -n "$key" ]] && log_info " Preserved: $key (.env value)"
done < "$preserved_keys_file"
fi
# Clean up lookup file
[[ -n "$lookup_file" ]] && rm -f "$lookup_file"
# Replace the original .env file
if mv "$new_env_file" ".env"; then
log_success "Successfully created new .env file"
else
log_error "Failed to replace .env file"
rm -f "$new_env_file"
return 1
fi
# Clean up difference file and temporary directory
if [[ -n "${TEMP_DIR:-}" ]]; then
rm -rf "${TEMP_DIR}"
unset TEMP_DIR
fi
if [[ -n "${DIFF_FILE:-}" ]]; then
unset DIFF_FILE
fi
log_success "Partial synchronization of .env file completed"
log_info " Preserved .env values: $preserved_count"
log_info " Updated to .env.example values: $updated_count"
}
# Detect removed environment variables
detect_removed_variables() {
log_info "Detecting removed environment variables..."
if [[ ! -f ".env" ]]; then
return
fi
# Use temporary files for efficient lookup
local temp_dir="${TEMP_DIR:-$(mktemp -d)}"
local temp_example_keys="$temp_dir/example_keys"
local temp_current_keys="$temp_dir/current_keys"
local cleanup_temp_dir=""
# Set flag if we created a new temp directory
if [[ -z "${TEMP_DIR:-}" ]]; then
cleanup_temp_dir="$temp_dir"
fi
# Get keys from .env.example and .env, sorted for comm
awk -F= '!/^[[:space:]]*#/ && /=/ {gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1); print $1}' .env.example | sort > "$temp_example_keys"
awk -F= '!/^[[:space:]]*#/ && /=/ {gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1); print $1}' .env | sort > "$temp_current_keys"
# Get keys from existing .env and check for removals
local removed_vars=()
while IFS= read -r var; do
removed_vars+=("$var")
done < <(comm -13 "$temp_example_keys" "$temp_current_keys")
# Clean up temporary files if we created a new temp directory
if [[ -n "$cleanup_temp_dir" ]]; then
rm -rf "$cleanup_temp_dir"
fi
if [[ ${#removed_vars[@]} -gt 0 ]]; then
log_warning "The following environment variables have been removed from .env.example:"
for var in "${removed_vars[@]}"; do
log_warning " - $var"
done
log_warning "Consider manually removing these variables from .env"
else
log_success "No removed environment variables found"
fi
}
# Show statistics
show_statistics() {
log_info "Synchronization statistics:"
local total_example=$(grep -c "^[^#]*=" .env.example 2>/dev/null || echo "0")
local total_env=$(grep -c "^[^#]*=" .env 2>/dev/null || echo "0")
log_info " .env.example environment variables: $total_example"
log_info " .env environment variables: $total_env"
}
# Main execution function
# Orchestrates the complete synchronization process in the correct order
main() {
log_info "=== Dify Environment Variables Synchronization Script ==="
log_info "Execution started: $(date)"
# Check prerequisites
check_files
# Create backup
create_backup
# Detect differences
detect_differences
# Detect removed variables (before sync)
detect_removed_variables
# Synchronize environment file
sync_env_file
# Show statistics
show_statistics
log_success "=== Synchronization process completed successfully ==="
log_info "Execution finished: $(date)"
}
# Execute main function only when script is run directly
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
main "$@"
fi

View File

@ -270,7 +270,7 @@ services:
# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.5.1-local
image: langgenius/dify-plugin-daemon:0.5.2-local
restart: always
environment:
# Use the shared environment variables.

View File

@ -123,7 +123,7 @@ services:
# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.5.1-local
image: langgenius/dify-plugin-daemon:0.5.2-local
restart: always
env_file:
- ./middleware.env

View File

@ -134,6 +134,7 @@ x-shared-env: &shared-api-worker-env
ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1}
ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path}
ALIYUN_CLOUDBOX_ID: ${ALIYUN_CLOUDBOX_ID:-your-cloudbox-id}
TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name}
TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key}
TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id}
@ -148,6 +149,7 @@ x-shared-env: &shared-api-worker-env
HUAWEI_OBS_SECRET_KEY: ${HUAWEI_OBS_SECRET_KEY:-your-secret-key}
HUAWEI_OBS_ACCESS_KEY: ${HUAWEI_OBS_ACCESS_KEY:-your-access-key}
HUAWEI_OBS_SERVER: ${HUAWEI_OBS_SERVER:-your-server-url}
HUAWEI_OBS_PATH_STYLE: ${HUAWEI_OBS_PATH_STYLE:-false}
VOLCENGINE_TOS_BUCKET_NAME: ${VOLCENGINE_TOS_BUCKET_NAME:-your-bucket-name}
VOLCENGINE_TOS_SECRET_KEY: ${VOLCENGINE_TOS_SECRET_KEY:-your-secret-key}
VOLCENGINE_TOS_ACCESS_KEY: ${VOLCENGINE_TOS_ACCESS_KEY:-your-access-key}
@ -939,7 +941,7 @@ services:
# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.5.1-local
image: langgenius/dify-plugin-daemon:0.5.2-local
restart: always
environment:
# Use the shared environment variables.

View File

@ -1,48 +1,40 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# Dependencies
node_modules/
# dependencies
/node_modules
/.pnp
.pnp.js
# Build output
dist/
# testing
/coverage
# Testing
coverage/
# next.js
/.next/
/out/
# IDE
.idea/
.vscode/
*.swp
*.swo
# production
/build
# misc
# OS
.DS_Store
*.pem
Thumbs.db
# debug
# Debug logs
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.pnpm-debug.log*
pnpm-debug.log*
# local env files
.env*.local
# Environment
.env
.env.local
.env.*.local
# vercel
.vercel
# typescript
# TypeScript
*.tsbuildinfo
next-env.d.ts
# npm
# Lock files (use pnpm-lock.yaml in CI if needed)
package-lock.json
yarn.lock
# yarn
.pnp.cjs
.pnp.loader.mjs
.yarn/
.yarnrc.yml
# pmpm
pnpm-lock.yaml
# Misc
*.pem
*.tgz

View File

@ -0,0 +1,22 @@
MIT License
Copyright (c) 2023 LangGenius
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -13,54 +13,92 @@ npm install dify-client
After installing the SDK, you can use it in your project like this:
```js
import { DifyClient, ChatClient, CompletionClient } from 'dify-client'
import {
DifyClient,
ChatClient,
CompletionClient,
WorkflowClient,
KnowledgeBaseClient,
WorkspaceClient
} from 'dify-client'
const API_KEY = 'your-api-key-here'
const user = `random-user-id`
const API_KEY = 'your-app-api-key'
const DATASET_API_KEY = 'your-dataset-api-key'
const user = 'random-user-id'
const query = 'Please tell me a short story in 10 words or less.'
const remote_url_files = [{
type: 'image',
transfer_method: 'remote_url',
url: 'your_url_address'
}]
// Create a completion client
const completionClient = new CompletionClient(API_KEY)
// Create a completion message
completionClient.createCompletionMessage({'query': query}, user)
// Create a completion message with vision model
completionClient.createCompletionMessage({'query': 'Describe the picture.'}, user, false, remote_url_files)
// Create a chat client
const chatClient = new ChatClient(API_KEY)
// Create a chat message in stream mode
const response = await chatClient.createChatMessage({}, query, user, true, null)
const stream = response.data;
stream.on('data', data => {
console.log(data);
});
stream.on('end', () => {
console.log('stream done');
});
// Create a chat message with vision model
chatClient.createChatMessage({}, 'Describe the picture.', user, false, null, remote_url_files)
// Fetch conversations
chatClient.getConversations(user)
// Fetch conversation messages
chatClient.getConversationMessages(conversationId, user)
// Rename conversation
chatClient.renameConversation(conversationId, name, user)
const completionClient = new CompletionClient(API_KEY)
const workflowClient = new WorkflowClient(API_KEY)
const kbClient = new KnowledgeBaseClient(DATASET_API_KEY)
const workspaceClient = new WorkspaceClient(DATASET_API_KEY)
const client = new DifyClient(API_KEY)
// Fetch application parameters
client.getApplicationParameters(user)
// Provide feedback for a message
client.messageFeedback(messageId, rating, user)
// App core
await client.getApplicationParameters(user)
await client.messageFeedback('message-id', 'like', user)
// Completion (blocking)
await completionClient.createCompletionMessage({
inputs: { query },
user,
response_mode: 'blocking'
})
// Chat (streaming)
const stream = await chatClient.createChatMessage({
inputs: {},
query,
user,
response_mode: 'streaming'
})
for await (const event of stream) {
console.log(event.event, event.data)
}
// Chatflow (advanced chat via workflow_id)
await chatClient.createChatMessage({
inputs: {},
query,
user,
workflow_id: 'workflow-id',
response_mode: 'blocking'
})
// Workflow run (blocking or streaming)
await workflowClient.run({
inputs: { query },
user,
response_mode: 'blocking'
})
// Knowledge base (dataset token required)
await kbClient.listDatasets({ page: 1, limit: 20 })
await kbClient.createDataset({ name: 'KB', indexing_technique: 'economy' })
// RAG pipeline (may require service API route registration)
const pipelineStream = await kbClient.runPipeline('dataset-id', {
inputs: {},
datasource_type: 'online_document',
datasource_info_list: [],
start_node_id: 'start-node-id',
is_published: true,
response_mode: 'streaming'
})
for await (const event of pipelineStream) {
console.log(event.data)
}
// Workspace models (dataset token required)
await workspaceClient.getModelsByType('text-embedding')
```
Replace 'your-api-key-here' with your actual Dify API key. Replace 'your-app-id-here' with your actual Dify APP ID.
Notes:
- App endpoints use an app API token; knowledge base and workspace endpoints use a dataset API token.
- Chat/completion require a stable `user` identifier in the request payload.
- For streaming responses, iterate the returned AsyncIterable. Use `stream.toText()` to collect text.
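
As a minimal sketch (assuming, per the note above, that the stream object exposes a `toText()` helper which resolves to the collected text), a streamed chat answer can be gathered like this:

```js
// Stream a chat response, then collect the whole answer as one string.
// `toText()` is the helper mentioned above; treating it as async is an assumption.
const answerStream = await chatClient.createChatMessage({
  inputs: {},
  query: 'Summarize Dify in one sentence.',
  user,
  response_mode: 'streaming'
})
const answer = await answerStream.toText()
console.log(answer)
```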
## License

View File

@ -1,12 +0,0 @@
module.exports = {
presets: [
[
"@babel/preset-env",
{
targets: {
node: "current",
},
},
],
],
};

View File

@ -0,0 +1,45 @@
import js from "@eslint/js";
import tsParser from "@typescript-eslint/parser";
import tsPlugin from "@typescript-eslint/eslint-plugin";
import { fileURLToPath } from "node:url";
import path from "node:path";
const tsconfigRootDir = path.dirname(fileURLToPath(import.meta.url));
const typeCheckedRules =
tsPlugin.configs["recommended-type-checked"]?.rules ??
tsPlugin.configs.recommendedTypeChecked?.rules ??
{};
export default [
{
ignores: ["dist", "node_modules", "scripts", "tests", "**/*.test.*", "**/*.spec.*"],
},
js.configs.recommended,
{
files: ["src/**/*.ts"],
languageOptions: {
parser: tsParser,
ecmaVersion: "latest",
parserOptions: {
project: "./tsconfig.json",
tsconfigRootDir,
sourceType: "module",
},
},
plugins: {
"@typescript-eslint": tsPlugin,
},
rules: {
...tsPlugin.configs.recommended.rules,
...typeCheckedRules,
"no-undef": "off",
"no-unused-vars": "off",
"@typescript-eslint/no-unsafe-call": "error",
"@typescript-eslint/no-unsafe-return": "error",
"@typescript-eslint/consistent-type-imports": [
"error",
{ prefer: "type-imports", fixStyle: "separate-type-imports" },
],
},
},
];

View File

@ -1,107 +0,0 @@
// Types.d.ts
export const BASE_URL: string;
export type RequestMethods = 'GET' | 'POST' | 'PATCH' | 'DELETE';
interface Params {
[key: string]: any;
}
interface HeaderParams {
[key: string]: string;
}
interface User {
}
interface DifyFileBase {
type: "image"
}
export interface DifyRemoteFile extends DifyFileBase {
transfer_method: "remote_url"
url: string
}
export interface DifyLocalFile extends DifyFileBase {
transfer_method: "local_file"
upload_file_id: string
}
export type DifyFile = DifyRemoteFile | DifyLocalFile;
export declare class DifyClient {
constructor(apiKey: string, baseUrl?: string);
updateApiKey(apiKey: string): void;
sendRequest(
method: RequestMethods,
endpoint: string,
data?: any,
params?: Params,
stream?: boolean,
headerParams?: HeaderParams
): Promise<any>;
messageFeedback(message_id: string, rating: number, user: User): Promise<any>;
getApplicationParameters(user: User): Promise<any>;
fileUpload(data: FormData): Promise<any>;
textToAudio(text: string ,user: string, streaming?: boolean): Promise<any>;
getMeta(user: User): Promise<any>;
}
export declare class CompletionClient extends DifyClient {
createCompletionMessage(
inputs: any,
user: User,
stream?: boolean,
files?: DifyFile[] | null
): Promise<any>;
}
export declare class ChatClient extends DifyClient {
createChatMessage(
inputs: any,
query: string,
user: User,
stream?: boolean,
conversation_id?: string | null,
files?: DifyFile[] | null
): Promise<any>;
getSuggested(message_id: string, user: User): Promise<any>;
stopMessage(task_id: string, user: User) : Promise<any>;
getConversations(
user: User,
first_id?: string | null,
limit?: number | null,
pinned?: boolean | null
): Promise<any>;
getConversationMessages(
user: User,
conversation_id?: string,
first_id?: string | null,
limit?: number | null
): Promise<any>;
renameConversation(conversation_id: string, name: string, user: User,auto_generate:boolean): Promise<any>;
deleteConversation(conversation_id: string, user: User): Promise<any>;
audioToText(data: FormData): Promise<any>;
}
export declare class WorkflowClient extends DifyClient {
run(inputs: any, user: User, stream?: boolean,): Promise<any>;
stop(task_id: string, user: User): Promise<any>;
}

View File

@ -1,351 +0,0 @@
import axios from "axios";
export const BASE_URL = "https://api.dify.ai/v1";
export const routes = {
// app's
feedback: {
method: "POST",
url: (message_id) => `/messages/${message_id}/feedbacks`,
},
application: {
method: "GET",
url: () => `/parameters`,
},
fileUpload: {
method: "POST",
url: () => `/files/upload`,
},
textToAudio: {
method: "POST",
url: () => `/text-to-audio`,
},
getMeta: {
method: "GET",
url: () => `/meta`,
},
// completion's
createCompletionMessage: {
method: "POST",
url: () => `/completion-messages`,
},
// chat's
createChatMessage: {
method: "POST",
url: () => `/chat-messages`,
},
getSuggested:{
method: "GET",
url: (message_id) => `/messages/${message_id}/suggested`,
},
stopChatMessage: {
method: "POST",
url: (task_id) => `/chat-messages/${task_id}/stop`,
},
getConversations: {
method: "GET",
url: () => `/conversations`,
},
getConversationMessages: {
method: "GET",
url: () => `/messages`,
},
renameConversation: {
method: "POST",
url: (conversation_id) => `/conversations/${conversation_id}/name`,
},
deleteConversation: {
method: "DELETE",
url: (conversation_id) => `/conversations/${conversation_id}`,
},
audioToText: {
method: "POST",
url: () => `/audio-to-text`,
},
// workflows
runWorkflow: {
method: "POST",
url: () => `/workflows/run`,
},
stopWorkflow: {
method: "POST",
url: (task_id) => `/workflows/tasks/${task_id}/stop`,
}
};
export class DifyClient {
constructor(apiKey, baseUrl = BASE_URL) {
this.apiKey = apiKey;
this.baseUrl = baseUrl;
}
updateApiKey(apiKey) {
this.apiKey = apiKey;
}
async sendRequest(
method,
endpoint,
data = null,
params = null,
stream = false,
headerParams = {}
) {
const isFormData =
(typeof FormData !== "undefined" && data instanceof FormData) ||
(data && data.constructor && data.constructor.name === "FormData");
const headers = {
Authorization: `Bearer ${this.apiKey}`,
...(isFormData ? {} : { "Content-Type": "application/json" }),
...headerParams,
};
const url = `${this.baseUrl}${endpoint}`;
let response;
if (stream) {
response = await axios({
method,
url,
data,
params,
headers,
responseType: "stream",
});
} else {
response = await axios({
method,
url,
...(method !== "GET" && { data }),
params,
headers,
responseType: "json",
});
}
return response;
}
messageFeedback(message_id, rating, user) {
const data = {
rating,
user,
};
return this.sendRequest(
routes.feedback.method,
routes.feedback.url(message_id),
data
);
}
getApplicationParameters(user) {
const params = { user };
return this.sendRequest(
routes.application.method,
routes.application.url(),
null,
params
);
}
fileUpload(data) {
return this.sendRequest(
routes.fileUpload.method,
routes.fileUpload.url(),
data
);
}
textToAudio(text, user, streaming = false) {
const data = {
text,
user,
streaming
};
return this.sendRequest(
routes.textToAudio.method,
routes.textToAudio.url(),
data,
null,
streaming
);
}
getMeta(user) {
const params = { user };
return this.sendRequest(
routes.getMeta.method,
routes.getMeta.url(),
null,
params
);
}
}
export class CompletionClient extends DifyClient {
createCompletionMessage(inputs, user, stream = false, files = null) {
const data = {
inputs,
user,
response_mode: stream ? "streaming" : "blocking",
files,
};
return this.sendRequest(
routes.createCompletionMessage.method,
routes.createCompletionMessage.url(),
data,
null,
stream
);
}
runWorkflow(inputs, user, stream = false, files = null) {
const data = {
inputs,
user,
response_mode: stream ? "streaming" : "blocking",
};
return this.sendRequest(
routes.runWorkflow.method,
routes.runWorkflow.url(),
data,
null,
stream
);
}
}
export class ChatClient extends DifyClient {
createChatMessage(
inputs,
query,
user,
stream = false,
conversation_id = null,
files = null
) {
const data = {
inputs,
query,
user,
response_mode: stream ? "streaming" : "blocking",
files,
};
if (conversation_id) data.conversation_id = conversation_id;
return this.sendRequest(
routes.createChatMessage.method,
routes.createChatMessage.url(),
data,
null,
stream
);
}
getSuggested(message_id, user) {
const data = { user };
return this.sendRequest(
routes.getSuggested.method,
routes.getSuggested.url(message_id),
data
);
}
stopMessage(task_id, user) {
const data = { user };
return this.sendRequest(
routes.stopChatMessage.method,
routes.stopChatMessage.url(task_id),
data
);
}
getConversations(user, first_id = null, limit = null, pinned = null) {
const params = { user, first_id: first_id, limit, pinned };
return this.sendRequest(
routes.getConversations.method,
routes.getConversations.url(),
null,
params
);
}
getConversationMessages(
user,
conversation_id = "",
first_id = null,
limit = null
) {
const params = { user };
if (conversation_id) params.conversation_id = conversation_id;
if (first_id) params.first_id = first_id;
if (limit) params.limit = limit;
return this.sendRequest(
routes.getConversationMessages.method,
routes.getConversationMessages.url(),
null,
params
);
}
renameConversation(conversation_id, name, user, auto_generate) {
const data = { name, user, auto_generate };
return this.sendRequest(
routes.renameConversation.method,
routes.renameConversation.url(conversation_id),
data
);
}
deleteConversation(conversation_id, user) {
const data = { user };
return this.sendRequest(
routes.deleteConversation.method,
routes.deleteConversation.url(conversation_id),
data
);
}
audioToText(data) {
return this.sendRequest(
routes.audioToText.method,
routes.audioToText.url(),
data
);
}
}
export class WorkflowClient extends DifyClient {
run(inputs,user,stream) {
const data = {
inputs,
response_mode: stream ? "streaming" : "blocking",
user
};
return this.sendRequest(
routes.runWorkflow.method,
routes.runWorkflow.url(),
data,
null,
stream
);
}
stop(task_id, user) {
const data = { user };
return this.sendRequest(
routes.stopWorkflow.method,
routes.stopWorkflow.url(task_id),
data
);
}
}

View File

@ -1,141 +0,0 @@
import { DifyClient, WorkflowClient, BASE_URL, routes } from ".";
import axios from 'axios'
jest.mock('axios')
afterEach(() => {
jest.resetAllMocks()
})
describe('Client', () => {
let difyClient
beforeEach(() => {
difyClient = new DifyClient('test')
})
test('should create a client', () => {
expect(difyClient).toBeDefined();
})
// test updateApiKey
test('should update the api key', () => {
difyClient.updateApiKey('test2');
expect(difyClient.apiKey).toBe('test2');
})
});
describe('Send Requests', () => {
let difyClient
beforeEach(() => {
difyClient = new DifyClient('test')
})
it('should make a successful request to the application parameter', async () => {
const method = 'GET'
const endpoint = routes.application.url()
const expectedResponse = { data: 'response' }
axios.mockResolvedValue(expectedResponse)
await difyClient.sendRequest(method, endpoint)
expect(axios).toHaveBeenCalledWith({
method,
url: `${BASE_URL}${endpoint}`,
params: null,
headers: {
Authorization: `Bearer ${difyClient.apiKey}`,
'Content-Type': 'application/json',
},
responseType: 'json',
})
})
it('should handle errors from the API', async () => {
const method = 'GET'
const endpoint = '/test-endpoint'
const errorMessage = 'Request failed with status code 404'
axios.mockRejectedValue(new Error(errorMessage))
await expect(difyClient.sendRequest(method, endpoint)).rejects.toThrow(
errorMessage
)
})
it('uses the getMeta route configuration', async () => {
axios.mockResolvedValue({ data: 'ok' })
await difyClient.getMeta('end-user')
expect(axios).toHaveBeenCalledWith({
method: routes.getMeta.method,
url: `${BASE_URL}${routes.getMeta.url()}`,
params: { user: 'end-user' },
headers: {
Authorization: `Bearer ${difyClient.apiKey}`,
'Content-Type': 'application/json',
},
responseType: 'json',
})
})
})
describe('File uploads', () => {
let difyClient
const OriginalFormData = global.FormData
beforeAll(() => {
global.FormData = class FormDataMock {}
})
afterAll(() => {
global.FormData = OriginalFormData
})
beforeEach(() => {
difyClient = new DifyClient('test')
})
it('does not override multipart boundary headers for FormData', async () => {
const form = new FormData()
axios.mockResolvedValue({ data: 'ok' })
await difyClient.fileUpload(form)
expect(axios).toHaveBeenCalledWith({
method: routes.fileUpload.method,
url: `${BASE_URL}${routes.fileUpload.url()}`,
data: form,
params: null,
headers: {
Authorization: `Bearer ${difyClient.apiKey}`,
},
responseType: 'json',
})
})
})
describe('Workflow client', () => {
let workflowClient
beforeEach(() => {
workflowClient = new WorkflowClient('test')
})
it('uses tasks stop path for workflow stop', async () => {
axios.mockResolvedValue({ data: 'stopped' })
await workflowClient.stop('task-1', 'end-user')
expect(axios).toHaveBeenCalledWith({
method: routes.stopWorkflow.method,
url: `${BASE_URL}${routes.stopWorkflow.url('task-1')}`,
data: { user: 'end-user' },
params: null,
headers: {
Authorization: `Bearer ${workflowClient.apiKey}`,
'Content-Type': 'application/json',
},
responseType: 'json',
})
})
})

View File

@ -1,6 +0,0 @@
module.exports = {
testEnvironment: "node",
transform: {
"^.+\\.[tj]sx?$": "babel-jest",
},
};

View File

@ -1,30 +1,70 @@
{
"name": "dify-client",
"version": "2.3.2",
"version": "3.0.0",
"description": "This is the Node.js SDK for the Dify.AI API, which allows you to easily integrate Dify.AI into your Node.js applications.",
"main": "index.js",
"type": "module",
"types":"index.d.ts",
"main": "./dist/index.js",
"types": "./dist/index.d.ts",
"exports": {
".": {
"types": "./dist/index.d.ts",
"import": "./dist/index.js"
}
},
"engines": {
"node": ">=18.0.0"
},
"files": [
"dist",
"README.md",
"LICENSE"
],
"keywords": [
"Dify",
"Dify.AI",
"LLM"
"LLM",
"AI",
"SDK",
"API"
],
"author": "Joel",
"author": "LangGenius",
"contributors": [
"<crazywoola> <<427733928@qq.com>> (https://github.com/crazywoola)"
"Joel <iamjoel007@gmail.com> (https://github.com/iamjoel)",
"lyzno1 <yuanyouhuilyz@gmail.com> (https://github.com/lyzno1)",
"crazywoola <427733928@qq.com> (https://github.com/crazywoola)"
],
"repository": {
"type": "git",
"url": "https://github.com/langgenius/dify.git",
"directory": "sdks/nodejs-client"
},
"bugs": {
"url": "https://github.com/langgenius/dify/issues"
},
"homepage": "https://dify.ai",
"license": "MIT",
"scripts": {
"test": "jest"
"build": "tsup",
"lint": "eslint",
"lint:fix": "eslint --fix",
"type-check": "tsc -p tsconfig.json --noEmit",
"test": "vitest run",
"test:coverage": "vitest run --coverage",
"publish:check": "./scripts/publish.sh --dry-run",
"publish:npm": "./scripts/publish.sh"
},
"dependencies": {
"axios": "^1.3.5"
},
"devDependencies": {
"@babel/core": "^7.21.8",
"@babel/preset-env": "^7.21.5",
"babel-jest": "^29.5.0",
"jest": "^29.5.0"
"@eslint/js": "^9.2.0",
"@types/node": "^20.11.30",
"@typescript-eslint/eslint-plugin": "^8.50.1",
"@typescript-eslint/parser": "^8.50.1",
"@vitest/coverage-v8": "1.6.1",
"eslint": "^9.2.0",
"tsup": "^8.5.1",
"typescript": "^5.4.5",
"vitest": "^1.5.0"
}
}

File diff suppressed because it is too large

View File

@ -0,0 +1,261 @@
#!/usr/bin/env bash
#
# Dify Node.js SDK Publish Script
# ================================
# A beautiful and reliable script to publish the SDK to npm
#
# Usage:
# ./scripts/publish.sh # Normal publish
# ./scripts/publish.sh --dry-run # Test without publishing
# ./scripts/publish.sh --skip-tests # Skip tests (not recommended)
#
set -euo pipefail
# ============================================================================
# Colors and Formatting
# ============================================================================
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
MAGENTA='\033[0;35m'
CYAN='\033[0;36m'
BOLD='\033[1m'
DIM='\033[2m'
NC='\033[0m' # No Color
# ============================================================================
# Helper Functions
# ============================================================================
print_banner() {
echo -e "${CYAN}"
echo "╔═══════════════════════════════════════════════════════════════╗"
echo "║ ║"
echo "║ 🚀 Dify Node.js SDK Publish Script 🚀 ║"
echo "║ ║"
echo "╚═══════════════════════════════════════════════════════════════╝"
echo -e "${NC}"
}
info() {
echo -e "${BLUE} ${NC}$1"
}
success() {
echo -e "${GREEN}${NC}$1"
}
warning() {
echo -e "${YELLOW}${NC}$1"
}
error() {
echo -e "${RED}${NC}$1"
}
step() {
echo -e "\n${MAGENTA}${BOLD}$1${NC}"
}
divider() {
echo -e "${DIM}─────────────────────────────────────────────────────────────────${NC}"
}
# ============================================================================
# Configuration
# ============================================================================
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
DRY_RUN=false
SKIP_TESTS=false
# Parse arguments
for arg in "$@"; do
case $arg in
--dry-run)
DRY_RUN=true
;;
--skip-tests)
SKIP_TESTS=true
;;
--help|-h)
echo "Usage: $0 [options]"
echo ""
echo "Options:"
echo " --dry-run Run without actually publishing"
echo " --skip-tests Skip running tests (not recommended)"
echo " --help, -h Show this help message"
exit 0
;;
esac
done
# ============================================================================
# Main Script
# ============================================================================
main() {
print_banner
cd "$PROJECT_DIR"
# Show mode
if [[ "$DRY_RUN" == true ]]; then
warning "Running in DRY-RUN mode - no actual publish will occur"
divider
fi
# ========================================================================
# Step 1: Environment Check
# ========================================================================
step "Step 1/6: Checking environment..."
# Check Node.js
if ! command -v node &> /dev/null; then
error "Node.js is not installed"
exit 1
fi
NODE_VERSION=$(node -v)
success "Node.js: $NODE_VERSION"
# Check npm
if ! command -v npm &> /dev/null; then
error "npm is not installed"
exit 1
fi
NPM_VERSION=$(npm -v)
success "npm: v$NPM_VERSION"
# Check pnpm (optional, for local dev)
if command -v pnpm &> /dev/null; then
PNPM_VERSION=$(pnpm -v)
success "pnpm: v$PNPM_VERSION"
else
info "pnpm not found (optional)"
fi
# Check npm login status
if ! npm whoami &> /dev/null; then
error "Not logged in to npm. Run 'npm login' first."
exit 1
fi
NPM_USER=$(npm whoami)
success "Logged in as: ${BOLD}$NPM_USER${NC}"
# ========================================================================
# Step 2: Read Package Info
# ========================================================================
step "Step 2/6: Reading package info..."
PACKAGE_NAME=$(node -p "require('./package.json').name")
PACKAGE_VERSION=$(node -p "require('./package.json').version")
success "Package: ${BOLD}$PACKAGE_NAME${NC}"
success "Version: ${BOLD}$PACKAGE_VERSION${NC}"
# Check if version already exists on npm
if npm view "$PACKAGE_NAME@$PACKAGE_VERSION" version &> /dev/null; then
error "Version $PACKAGE_VERSION already exists on npm!"
echo ""
info "Current published versions:"
npm view "$PACKAGE_NAME" versions --json 2>/dev/null | tail -5
echo ""
warning "Please update the version in package.json before publishing."
exit 1
fi
success "Version $PACKAGE_VERSION is available"
# ========================================================================
# Step 3: Install Dependencies
# ========================================================================
step "Step 3/6: Installing dependencies..."
if command -v pnpm &> /dev/null; then
pnpm install --frozen-lockfile 2>/dev/null || pnpm install
else
npm ci 2>/dev/null || npm install
fi
success "Dependencies installed"
# ========================================================================
# Step 4: Run Tests
# ========================================================================
step "Step 4/6: Running tests..."
if [[ "$SKIP_TESTS" == true ]]; then
warning "Skipping tests (--skip-tests flag)"
else
if command -v pnpm &> /dev/null; then
pnpm test
else
npm test
fi
success "All tests passed"
fi
# ========================================================================
# Step 5: Build
# ========================================================================
step "Step 5/6: Building package..."
# Clean previous build
rm -rf dist
if command -v pnpm &> /dev/null; then
pnpm run build
else
npm run build
fi
success "Build completed"
# Verify build output
if [[ ! -f "dist/index.js" ]]; then
error "Build failed - dist/index.js not found"
exit 1
fi
if [[ ! -f "dist/index.d.ts" ]]; then
error "Build failed - dist/index.d.ts not found"
exit 1
fi
success "Build output verified"
# ========================================================================
# Step 6: Publish
# ========================================================================
step "Step 6/6: Publishing to npm..."
divider
echo -e "${CYAN}Package contents:${NC}"
npm pack --dry-run 2>&1 | head -30
divider
if [[ "$DRY_RUN" == true ]]; then
warning "DRY-RUN: Skipping actual publish"
echo ""
info "To publish for real, run without --dry-run flag"
else
echo ""
echo -e "${YELLOW}About to publish ${BOLD}$PACKAGE_NAME@$PACKAGE_VERSION${NC}${YELLOW} to npm${NC}"
echo -e "${DIM}Press Enter to continue, or Ctrl+C to cancel...${NC}"
read -r
npm publish --access public
echo ""
success "🎉 Successfully published ${BOLD}$PACKAGE_NAME@$PACKAGE_VERSION${NC} to npm!"
echo ""
echo -e "${GREEN}Install with:${NC}"
echo -e " ${CYAN}npm install $PACKAGE_NAME${NC}"
echo -e " ${CYAN}pnpm add $PACKAGE_NAME${NC}"
echo -e " ${CYAN}yarn add $PACKAGE_NAME${NC}"
echo ""
echo -e "${GREEN}View on npm:${NC}"
echo -e " ${CYAN}https://www.npmjs.com/package/$PACKAGE_NAME${NC}"
fi
divider
echo -e "${GREEN}${BOLD}✨ All done!${NC}"
}
# Run main function
main "$@"

View File

@ -0,0 +1,175 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import { DifyClient } from "./base";
import { ValidationError } from "../errors/dify-error";
import { createHttpClientWithSpies } from "../../tests/test-utils";
describe("DifyClient base", () => {
beforeEach(() => {
vi.restoreAllMocks();
});
it("getRoot calls root endpoint", async () => {
const { client, request } = createHttpClientWithSpies();
const dify = new DifyClient(client);
await dify.getRoot();
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/",
});
});
it("getApplicationParameters includes optional user", async () => {
const { client, request } = createHttpClientWithSpies();
const dify = new DifyClient(client);
await dify.getApplicationParameters();
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/parameters",
query: undefined,
});
await dify.getApplicationParameters("user-1");
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/parameters",
query: { user: "user-1" },
});
});
it("getMeta includes optional user", async () => {
const { client, request } = createHttpClientWithSpies();
const dify = new DifyClient(client);
await dify.getMeta("user-1");
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/meta",
query: { user: "user-1" },
});
});
it("getInfo and getSite support optional user", async () => {
const { client, request } = createHttpClientWithSpies();
const dify = new DifyClient(client);
await dify.getInfo();
await dify.getSite("user");
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/info",
query: undefined,
});
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/site",
query: { user: "user" },
});
});
it("messageFeedback builds payload from request object", async () => {
const { client, request } = createHttpClientWithSpies();
const dify = new DifyClient(client);
await dify.messageFeedback({
messageId: "msg",
user: "user",
rating: "like",
content: "good",
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/messages/msg/feedbacks",
data: { user: "user", rating: "like", content: "good" },
});
});
it("fileUpload appends user to form data", async () => {
const { client, request } = createHttpClientWithSpies();
const dify = new DifyClient(client);
const form = { append: vi.fn(), getHeaders: () => ({}) };
await dify.fileUpload(form, "user");
expect(form.append).toHaveBeenCalledWith("user", "user");
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/files/upload",
data: form,
});
});
it("filePreview uses arraybuffer response", async () => {
const { client, request } = createHttpClientWithSpies();
const dify = new DifyClient(client);
await dify.filePreview("file", "user", true);
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/files/file/preview",
query: { user: "user", as_attachment: "true" },
responseType: "arraybuffer",
});
});
it("audioToText appends user and sends form", async () => {
const { client, request } = createHttpClientWithSpies();
const dify = new DifyClient(client);
const form = { append: vi.fn(), getHeaders: () => ({}) };
await dify.audioToText(form, "user");
expect(form.append).toHaveBeenCalledWith("user", "user");
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/audio-to-text",
data: form,
});
});
it("textToAudio supports streaming and message id", async () => {
const { client, request, requestBinaryStream } = createHttpClientWithSpies();
const dify = new DifyClient(client);
await dify.textToAudio({
user: "user",
message_id: "msg",
streaming: true,
});
expect(requestBinaryStream).toHaveBeenCalledWith({
method: "POST",
path: "/text-to-audio",
data: {
user: "user",
message_id: "msg",
streaming: true,
},
});
await dify.textToAudio("hello", "user", false, "voice");
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/text-to-audio",
data: {
text: "hello",
user: "user",
streaming: false,
voice: "voice",
},
responseType: "arraybuffer",
});
});
it("textToAudio requires text or message id", async () => {
const { client } = createHttpClientWithSpies();
const dify = new DifyClient(client);
expect(() => dify.textToAudio({ user: "user" })).toThrow(ValidationError);
});
});

View File

@ -0,0 +1,284 @@
import type {
BinaryStream,
DifyClientConfig,
DifyResponse,
MessageFeedbackRequest,
QueryParams,
RequestMethod,
TextToAudioRequest,
} from "../types/common";
import { HttpClient } from "../http/client";
import { ensureNonEmptyString, ensureRating } from "./validation";
import { FileUploadError, ValidationError } from "../errors/dify-error";
import { isFormData } from "../http/form-data";
const toConfig = (
init: string | DifyClientConfig,
baseUrl?: string
): DifyClientConfig => {
if (typeof init === "string") {
return {
apiKey: init,
baseUrl,
};
}
return init;
};
const appendUserToFormData = (form: unknown, user: string): void => {
if (!isFormData(form)) {
throw new FileUploadError("FormData is required for file uploads");
}
if (typeof form.append === "function") {
form.append("user", user);
}
};
export class DifyClient {
protected http: HttpClient;
constructor(config: string | DifyClientConfig | HttpClient, baseUrl?: string) {
if (config instanceof HttpClient) {
this.http = config;
} else {
this.http = new HttpClient(toConfig(config, baseUrl));
}
}
updateApiKey(apiKey: string): void {
ensureNonEmptyString(apiKey, "apiKey");
this.http.updateApiKey(apiKey);
}
getHttpClient(): HttpClient {
return this.http;
}
sendRequest(
method: RequestMethod,
endpoint: string,
data: unknown = null,
params: QueryParams | null = null,
stream = false,
headerParams: Record<string, string> = {}
): ReturnType<HttpClient["requestRaw"]> {
return this.http.requestRaw({
method,
path: endpoint,
data,
query: params ?? undefined,
headers: headerParams,
responseType: stream ? "stream" : "json",
});
}
getRoot(): Promise<DifyResponse<unknown>> {
return this.http.request({
method: "GET",
path: "/",
});
}
getApplicationParameters(user?: string): Promise<DifyResponse<unknown>> {
if (user) {
ensureNonEmptyString(user, "user");
}
return this.http.request({
method: "GET",
path: "/parameters",
query: user ? { user } : undefined,
});
}
async getParameters(user?: string): Promise<DifyResponse<unknown>> {
return this.getApplicationParameters(user);
}
getMeta(user?: string): Promise<DifyResponse<unknown>> {
if (user) {
ensureNonEmptyString(user, "user");
}
return this.http.request({
method: "GET",
path: "/meta",
query: user ? { user } : undefined,
});
}
messageFeedback(
request: MessageFeedbackRequest
): Promise<DifyResponse<Record<string, unknown>>>;
messageFeedback(
messageId: string,
rating: "like" | "dislike" | null,
user: string,
content?: string
): Promise<DifyResponse<Record<string, unknown>>>;
messageFeedback(
messageIdOrRequest: string | MessageFeedbackRequest,
rating?: "like" | "dislike" | null,
user?: string,
content?: string
): Promise<DifyResponse<Record<string, unknown>>> {
let messageId: string;
const payload: Record<string, unknown> = {};
if (typeof messageIdOrRequest === "string") {
messageId = messageIdOrRequest;
ensureNonEmptyString(messageId, "messageId");
ensureNonEmptyString(user, "user");
payload.user = user;
if (rating !== undefined && rating !== null) {
ensureRating(rating);
payload.rating = rating;
}
if (content !== undefined) {
payload.content = content;
}
} else {
const request = messageIdOrRequest;
messageId = request.messageId;
ensureNonEmptyString(messageId, "messageId");
ensureNonEmptyString(request.user, "user");
payload.user = request.user;
if (request.rating !== undefined && request.rating !== null) {
ensureRating(request.rating);
payload.rating = request.rating;
}
if (request.content !== undefined) {
payload.content = request.content;
}
}
return this.http.request({
method: "POST",
path: `/messages/${messageId}/feedbacks`,
data: payload,
});
}
getInfo(user?: string): Promise<DifyResponse<unknown>> {
if (user) {
ensureNonEmptyString(user, "user");
}
return this.http.request({
method: "GET",
path: "/info",
query: user ? { user } : undefined,
});
}
getSite(user?: string): Promise<DifyResponse<unknown>> {
if (user) {
ensureNonEmptyString(user, "user");
}
return this.http.request({
method: "GET",
path: "/site",
query: user ? { user } : undefined,
});
}
fileUpload(form: unknown, user: string): Promise<DifyResponse<unknown>> {
if (!isFormData(form)) {
throw new FileUploadError("FormData is required for file uploads");
}
ensureNonEmptyString(user, "user");
appendUserToFormData(form, user);
return this.http.request({
method: "POST",
path: "/files/upload",
data: form,
});
}
filePreview(
fileId: string,
user: string,
asAttachment?: boolean
): Promise<DifyResponse<Buffer>> {
ensureNonEmptyString(fileId, "fileId");
ensureNonEmptyString(user, "user");
return this.http.request<Buffer>({
method: "GET",
path: `/files/${fileId}/preview`,
query: {
user,
as_attachment: asAttachment ? "true" : undefined,
},
responseType: "arraybuffer",
});
}
audioToText(form: unknown, user: string): Promise<DifyResponse<unknown>> {
if (!isFormData(form)) {
throw new FileUploadError("FormData is required for audio uploads");
}
ensureNonEmptyString(user, "user");
appendUserToFormData(form, user);
return this.http.request({
method: "POST",
path: "/audio-to-text",
data: form,
});
}
textToAudio(
request: TextToAudioRequest
): Promise<DifyResponse<Buffer> | BinaryStream>;
textToAudio(
text: string,
user: string,
streaming?: boolean,
voice?: string
): Promise<DifyResponse<Buffer> | BinaryStream>;
textToAudio(
textOrRequest: string | TextToAudioRequest,
user?: string,
streaming = false,
voice?: string
): Promise<DifyResponse<Buffer> | BinaryStream> {
let payload: TextToAudioRequest;
if (typeof textOrRequest === "string") {
ensureNonEmptyString(textOrRequest, "text");
ensureNonEmptyString(user, "user");
payload = {
text: textOrRequest,
user,
streaming,
};
if (voice) {
payload.voice = voice;
}
} else {
payload = { ...textOrRequest };
ensureNonEmptyString(payload.user, "user");
if (payload.text !== undefined && payload.text !== null) {
ensureNonEmptyString(payload.text, "text");
}
if (payload.message_id !== undefined && payload.message_id !== null) {
ensureNonEmptyString(payload.message_id, "messageId");
}
if (!payload.text && !payload.message_id) {
throw new ValidationError("text or message_id is required");
}
payload.streaming = payload.streaming ?? false;
}
if (payload.streaming) {
return this.http.requestBinaryStream({
method: "POST",
path: "/text-to-audio",
data: payload,
});
}
return this.http.request<Buffer>({
method: "POST",
path: "/text-to-audio",
data: payload,
responseType: "arraybuffer",
});
}
}
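
For orientation, a minimal usage sketch of the `DifyClient` base class defined above. The `dify-client` import name, the API key, and the user id are placeholders assumed for illustration, not values taken from this diff:

```typescript
import { DifyClient } from "dify-client"; // assumed package name

async function main() {
  // The constructor accepts either a bare API-key string or a DifyClientConfig object.
  const client = new DifyClient({ apiKey: "app-xxx", baseUrl: "https://api.dify.ai/v1" });

  // Fetch app parameters for a given end user (the user argument is optional).
  const params = await client.getApplicationParameters("end-user-1");
  console.log(params);

  // Submit message feedback via the request-object overload.
  await client.messageFeedback({
    messageId: "message-id",
    user: "end-user-1",
    rating: "like",
    content: "Helpful answer",
  });
}

main();
```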

View File

@ -0,0 +1,239 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import { ChatClient } from "./chat";
import { ValidationError } from "../errors/dify-error";
import { createHttpClientWithSpies } from "../../tests/test-utils";
describe("ChatClient", () => {
beforeEach(() => {
vi.restoreAllMocks();
});
it("creates chat messages in blocking mode", async () => {
const { client, request } = createHttpClientWithSpies();
const chat = new ChatClient(client);
await chat.createChatMessage({ input: "x" }, "hello", "user", false, null);
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/chat-messages",
data: {
inputs: { input: "x" },
query: "hello",
user: "user",
response_mode: "blocking",
files: undefined,
},
});
});
it("creates chat messages in streaming mode", async () => {
const { client, requestStream } = createHttpClientWithSpies();
const chat = new ChatClient(client);
await chat.createChatMessage({
inputs: { input: "x" },
query: "hello",
user: "user",
response_mode: "streaming",
});
expect(requestStream).toHaveBeenCalledWith({
method: "POST",
path: "/chat-messages",
data: {
inputs: { input: "x" },
query: "hello",
user: "user",
response_mode: "streaming",
},
});
});
it("stops chat messages", async () => {
const { client, request } = createHttpClientWithSpies();
const chat = new ChatClient(client);
await chat.stopChatMessage("task", "user");
await chat.stopMessage("task", "user");
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/chat-messages/task/stop",
data: { user: "user" },
});
});
it("gets suggested questions", async () => {
const { client, request } = createHttpClientWithSpies();
const chat = new ChatClient(client);
await chat.getSuggested("msg", "user");
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/messages/msg/suggested",
query: { user: "user" },
});
});
it("submits message feedback", async () => {
const { client, request } = createHttpClientWithSpies();
const chat = new ChatClient(client);
await chat.messageFeedback("msg", "like", "user", "good");
await chat.messageFeedback({
messageId: "msg",
user: "user",
rating: "dislike",
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/messages/msg/feedbacks",
data: { user: "user", rating: "like", content: "good" },
});
});
it("lists app feedbacks", async () => {
const { client, request } = createHttpClientWithSpies();
const chat = new ChatClient(client);
await chat.getAppFeedbacks(2, 5);
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/app/feedbacks",
query: { page: 2, limit: 5 },
});
});
it("lists conversations and messages", async () => {
const { client, request } = createHttpClientWithSpies();
const chat = new ChatClient(client);
await chat.getConversations("user", "last", 10, "-updated_at");
await chat.getConversationMessages("user", "conv", "first", 5);
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/conversations",
query: {
user: "user",
last_id: "last",
limit: 10,
sort_by: "-updated_at",
},
});
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/messages",
query: {
user: "user",
conversation_id: "conv",
first_id: "first",
limit: 5,
},
});
});
it("renames conversations with optional auto-generate", async () => {
const { client, request } = createHttpClientWithSpies();
const chat = new ChatClient(client);
await chat.renameConversation("conv", "name", "user", false);
await chat.renameConversation("conv", "user", { autoGenerate: true });
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/conversations/conv/name",
data: { user: "user", auto_generate: false, name: "name" },
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/conversations/conv/name",
data: { user: "user", auto_generate: true },
});
});
it("requires name when autoGenerate is false", async () => {
const { client } = createHttpClientWithSpies();
const chat = new ChatClient(client);
expect(() =>
chat.renameConversation("conv", "", "user", false)
).toThrow(ValidationError);
});
it("deletes conversations", async () => {
const { client, request } = createHttpClientWithSpies();
const chat = new ChatClient(client);
await chat.deleteConversation("conv", "user");
expect(request).toHaveBeenCalledWith({
method: "DELETE",
path: "/conversations/conv",
data: { user: "user" },
});
});
it("manages conversation variables", async () => {
const { client, request } = createHttpClientWithSpies();
const chat = new ChatClient(client);
await chat.getConversationVariables("conv", "user", "last", 10, "name");
await chat.updateConversationVariable("conv", "var", "user", "value");
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/conversations/conv/variables",
query: {
user: "user",
last_id: "last",
limit: 10,
variable_name: "name",
},
});
expect(request).toHaveBeenCalledWith({
method: "PUT",
path: "/conversations/conv/variables/var",
data: { user: "user", value: "value" },
});
});
it("handles annotation APIs", async () => {
const { client, request } = createHttpClientWithSpies();
const chat = new ChatClient(client);
await chat.annotationReplyAction("enable", {
score_threshold: 0.5,
embedding_provider_name: "prov",
embedding_model_name: "model",
});
await chat.getAnnotationReplyStatus("enable", "job");
await chat.listAnnotations({ page: 1, limit: 10, keyword: "k" });
await chat.createAnnotation({ question: "q", answer: "a" });
await chat.updateAnnotation("id", { question: "q", answer: "a" });
await chat.deleteAnnotation("id");
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/apps/annotation-reply/enable",
data: {
score_threshold: 0.5,
embedding_provider_name: "prov",
embedding_model_name: "model",
},
});
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/apps/annotation-reply/enable/status/job",
});
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/apps/annotations",
query: { page: 1, limit: 10, keyword: "k" },
});
});
});

View File

@ -0,0 +1,377 @@
import { DifyClient } from "./base";
import type { ChatMessageRequest, ChatMessageResponse } from "../types/chat";
import type {
AnnotationCreateRequest,
AnnotationListOptions,
AnnotationReplyActionRequest,
AnnotationResponse,
} from "../types/annotation";
import type {
DifyResponse,
DifyStream,
QueryParams,
} from "../types/common";
import {
ensureNonEmptyString,
ensureOptionalInt,
ensureOptionalString,
} from "./validation";
export class ChatClient extends DifyClient {
createChatMessage(
request: ChatMessageRequest
): Promise<DifyResponse<ChatMessageResponse> | DifyStream<ChatMessageResponse>>;
createChatMessage(
inputs: Record<string, unknown>,
query: string,
user: string,
stream?: boolean,
conversationId?: string | null,
files?: Array<Record<string, unknown>> | null
): Promise<DifyResponse<ChatMessageResponse> | DifyStream<ChatMessageResponse>>;
createChatMessage(
inputOrRequest: ChatMessageRequest | Record<string, unknown>,
query?: string,
user?: string,
stream = false,
conversationId?: string | null,
files?: Array<Record<string, unknown>> | null
): Promise<DifyResponse<ChatMessageResponse> | DifyStream<ChatMessageResponse>> {
let payload: ChatMessageRequest;
let shouldStream = stream;
if (query === undefined && "user" in (inputOrRequest as ChatMessageRequest)) {
payload = inputOrRequest as ChatMessageRequest;
shouldStream = payload.response_mode === "streaming";
} else {
ensureNonEmptyString(query, "query");
ensureNonEmptyString(user, "user");
payload = {
inputs: inputOrRequest as Record<string, unknown>,
query,
user,
response_mode: stream ? "streaming" : "blocking",
files,
};
if (conversationId) {
payload.conversation_id = conversationId;
}
}
ensureNonEmptyString(payload.user, "user");
ensureNonEmptyString(payload.query, "query");
if (shouldStream) {
return this.http.requestStream<ChatMessageResponse>({
method: "POST",
path: "/chat-messages",
data: payload,
});
}
return this.http.request<ChatMessageResponse>({
method: "POST",
path: "/chat-messages",
data: payload,
});
}
stopChatMessage(
taskId: string,
user: string
): Promise<DifyResponse<ChatMessageResponse>> {
ensureNonEmptyString(taskId, "taskId");
ensureNonEmptyString(user, "user");
return this.http.request<ChatMessageResponse>({
method: "POST",
path: `/chat-messages/${taskId}/stop`,
data: { user },
});
}
stopMessage(
taskId: string,
user: string
): Promise<DifyResponse<ChatMessageResponse>> {
return this.stopChatMessage(taskId, user);
}
getSuggested(
messageId: string,
user: string
): Promise<DifyResponse<ChatMessageResponse>> {
ensureNonEmptyString(messageId, "messageId");
ensureNonEmptyString(user, "user");
return this.http.request<ChatMessageResponse>({
method: "GET",
path: `/messages/${messageId}/suggested`,
query: { user },
});
}
// Note: messageFeedback is inherited from DifyClient
getAppFeedbacks(
page?: number,
limit?: number
): Promise<DifyResponse<Record<string, unknown>>> {
ensureOptionalInt(page, "page");
ensureOptionalInt(limit, "limit");
return this.http.request({
method: "GET",
path: "/app/feedbacks",
query: {
page,
limit,
},
});
}
getConversations(
user: string,
lastId?: string | null,
limit?: number | null,
sortByOrPinned?: string | boolean | null
): Promise<DifyResponse<Record<string, unknown>>> {
ensureNonEmptyString(user, "user");
ensureOptionalString(lastId, "lastId");
ensureOptionalInt(limit, "limit");
const params: QueryParams = { user };
if (lastId) {
params.last_id = lastId;
}
if (limit) {
params.limit = limit;
}
if (typeof sortByOrPinned === "string") {
params.sort_by = sortByOrPinned;
} else if (typeof sortByOrPinned === "boolean") {
params.pinned = sortByOrPinned;
}
return this.http.request({
method: "GET",
path: "/conversations",
query: params,
});
}
getConversationMessages(
user: string,
conversationId: string,
firstId?: string | null,
limit?: number | null
): Promise<DifyResponse<Record<string, unknown>>> {
ensureNonEmptyString(user, "user");
ensureNonEmptyString(conversationId, "conversationId");
ensureOptionalString(firstId, "firstId");
ensureOptionalInt(limit, "limit");
const params: QueryParams = { user };
params.conversation_id = conversationId;
if (firstId) {
params.first_id = firstId;
}
if (limit) {
params.limit = limit;
}
return this.http.request({
method: "GET",
path: "/messages",
query: params,
});
}
renameConversation(
conversationId: string,
name: string,
user: string,
autoGenerate?: boolean
): Promise<DifyResponse<Record<string, unknown>>>;
renameConversation(
conversationId: string,
user: string,
options?: { name?: string | null; autoGenerate?: boolean }
): Promise<DifyResponse<Record<string, unknown>>>;
renameConversation(
conversationId: string,
nameOrUser: string,
userOrOptions?: string | { name?: string | null; autoGenerate?: boolean },
autoGenerate?: boolean
): Promise<DifyResponse<Record<string, unknown>>> {
ensureNonEmptyString(conversationId, "conversationId");
let name: string | null | undefined;
let user: string;
let resolvedAutoGenerate: boolean;
if (typeof userOrOptions === "string" || userOrOptions === undefined) {
name = nameOrUser;
user = userOrOptions ?? "";
resolvedAutoGenerate = autoGenerate ?? false;
} else {
user = nameOrUser;
name = userOrOptions.name;
resolvedAutoGenerate = userOrOptions.autoGenerate ?? false;
}
ensureNonEmptyString(user, "user");
if (!resolvedAutoGenerate) {
ensureNonEmptyString(name, "name");
}
const payload: Record<string, unknown> = {
user,
auto_generate: resolvedAutoGenerate,
};
if (typeof name === "string" && name.trim().length > 0) {
payload.name = name;
}
return this.http.request({
method: "POST",
path: `/conversations/${conversationId}/name`,
data: payload,
});
}
deleteConversation(
conversationId: string,
user: string
): Promise<DifyResponse<Record<string, unknown>>> {
ensureNonEmptyString(conversationId, "conversationId");
ensureNonEmptyString(user, "user");
return this.http.request({
method: "DELETE",
path: `/conversations/${conversationId}`,
data: { user },
});
}
getConversationVariables(
conversationId: string,
user: string,
lastId?: string | null,
limit?: number | null,
variableName?: string | null
): Promise<DifyResponse<Record<string, unknown>>> {
ensureNonEmptyString(conversationId, "conversationId");
ensureNonEmptyString(user, "user");
ensureOptionalString(lastId, "lastId");
ensureOptionalInt(limit, "limit");
ensureOptionalString(variableName, "variableName");
return this.http.request({
method: "GET",
path: `/conversations/${conversationId}/variables`,
query: {
user,
last_id: lastId ?? undefined,
limit: limit ?? undefined,
variable_name: variableName ?? undefined,
},
});
}
updateConversationVariable(
conversationId: string,
variableId: string,
user: string,
value: unknown
): Promise<DifyResponse<Record<string, unknown>>> {
ensureNonEmptyString(conversationId, "conversationId");
ensureNonEmptyString(variableId, "variableId");
ensureNonEmptyString(user, "user");
return this.http.request({
method: "PUT",
path: `/conversations/${conversationId}/variables/${variableId}`,
data: {
user,
value,
},
});
}
annotationReplyAction(
action: "enable" | "disable",
request: AnnotationReplyActionRequest
): Promise<DifyResponse<AnnotationResponse>> {
ensureNonEmptyString(action, "action");
ensureNonEmptyString(request.embedding_provider_name, "embedding_provider_name");
ensureNonEmptyString(request.embedding_model_name, "embedding_model_name");
return this.http.request({
method: "POST",
path: `/apps/annotation-reply/${action}`,
data: request,
});
}
getAnnotationReplyStatus(
action: "enable" | "disable",
jobId: string
): Promise<DifyResponse<AnnotationResponse>> {
ensureNonEmptyString(action, "action");
ensureNonEmptyString(jobId, "jobId");
return this.http.request({
method: "GET",
path: `/apps/annotation-reply/${action}/status/${jobId}`,
});
}
listAnnotations(
options?: AnnotationListOptions
): Promise<DifyResponse<AnnotationResponse>> {
ensureOptionalInt(options?.page, "page");
ensureOptionalInt(options?.limit, "limit");
ensureOptionalString(options?.keyword, "keyword");
return this.http.request({
method: "GET",
path: "/apps/annotations",
query: {
page: options?.page,
limit: options?.limit,
keyword: options?.keyword ?? undefined,
},
});
}
createAnnotation(
request: AnnotationCreateRequest
): Promise<DifyResponse<AnnotationResponse>> {
ensureNonEmptyString(request.question, "question");
ensureNonEmptyString(request.answer, "answer");
return this.http.request({
method: "POST",
path: "/apps/annotations",
data: request,
});
}
updateAnnotation(
annotationId: string,
request: AnnotationCreateRequest
): Promise<DifyResponse<AnnotationResponse>> {
ensureNonEmptyString(annotationId, "annotationId");
ensureNonEmptyString(request.question, "question");
ensureNonEmptyString(request.answer, "answer");
return this.http.request({
method: "PUT",
path: `/apps/annotations/${annotationId}`,
data: request,
});
}
deleteAnnotation(
annotationId: string
): Promise<DifyResponse<AnnotationResponse>> {
ensureNonEmptyString(annotationId, "annotationId");
return this.http.request({
method: "DELETE",
path: `/apps/annotations/${annotationId}`,
});
}
// Note: audioToText is inherited from DifyClient
}
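
A small usage sketch of `ChatClient`, again with an assumed `dify-client` import and placeholder ids. Blocking mode resolves to a single response, while `response_mode: "streaming"` routes through `requestStream` and returns a stream whose exact shape (`DifyStream`) is defined elsewhere in this change:

```typescript
import { ChatClient } from "dify-client"; // assumed package name

async function main() {
  const chat = new ChatClient("app-xxx");

  // Blocking chat message via the request-object overload.
  const reply = await chat.createChatMessage({
    inputs: {},
    query: "Hello!",
    user: "end-user-1",
    response_mode: "blocking",
  });
  console.log(reply);

  // Streaming mode returns a DifyStream instead of a plain response.
  const stream = await chat.createChatMessage({
    inputs: {},
    query: "Summarize our last conversation",
    user: "end-user-1",
    response_mode: "streaming",
  });
  // Consume `stream` according to the DifyStream contract (not shown in this file).
}

main();
```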

View File

@ -0,0 +1,83 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import { CompletionClient } from "./completion";
import { createHttpClientWithSpies } from "../../tests/test-utils";
describe("CompletionClient", () => {
beforeEach(() => {
vi.restoreAllMocks();
});
it("creates completion messages in blocking mode", async () => {
const { client, request } = createHttpClientWithSpies();
const completion = new CompletionClient(client);
await completion.createCompletionMessage({ input: "x" }, "user", false);
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/completion-messages",
data: {
inputs: { input: "x" },
user: "user",
files: undefined,
response_mode: "blocking",
},
});
});
it("creates completion messages in streaming mode", async () => {
const { client, requestStream } = createHttpClientWithSpies();
const completion = new CompletionClient(client);
await completion.createCompletionMessage({
inputs: { input: "x" },
user: "user",
response_mode: "streaming",
});
expect(requestStream).toHaveBeenCalledWith({
method: "POST",
path: "/completion-messages",
data: {
inputs: { input: "x" },
user: "user",
response_mode: "streaming",
},
});
});
it("stops completion messages", async () => {
const { client, request } = createHttpClientWithSpies();
const completion = new CompletionClient(client);
await completion.stopCompletionMessage("task", "user");
await completion.stop("task", "user");
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/completion-messages/task/stop",
data: { user: "user" },
});
});
it("supports deprecated runWorkflow", async () => {
const { client, request, requestStream } = createHttpClientWithSpies();
const completion = new CompletionClient(client);
const warn = vi.spyOn(console, "warn").mockImplementation(() => {});
await completion.runWorkflow({ input: "x" }, "user", false);
await completion.runWorkflow({ input: "x" }, "user", true);
expect(warn).toHaveBeenCalled();
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/workflows/run",
data: { inputs: { input: "x" }, user: "user", response_mode: "blocking" },
});
expect(requestStream).toHaveBeenCalledWith({
method: "POST",
path: "/workflows/run",
data: { inputs: { input: "x" }, user: "user", response_mode: "streaming" },
});
});
});

View File

@ -0,0 +1,111 @@
import { DifyClient } from "./base";
import type { CompletionRequest, CompletionResponse } from "../types/completion";
import type { DifyResponse, DifyStream } from "../types/common";
import { ensureNonEmptyString } from "./validation";
const warned = new Set<string>();
const warnOnce = (message: string): void => {
if (warned.has(message)) {
return;
}
warned.add(message);
console.warn(message);
};
export class CompletionClient extends DifyClient {
createCompletionMessage(
request: CompletionRequest
): Promise<DifyResponse<CompletionResponse> | DifyStream<CompletionResponse>>;
createCompletionMessage(
inputs: Record<string, unknown>,
user: string,
stream?: boolean,
files?: Array<Record<string, unknown>> | null
): Promise<DifyResponse<CompletionResponse> | DifyStream<CompletionResponse>>;
createCompletionMessage(
inputOrRequest: CompletionRequest | Record<string, unknown>,
user?: string,
stream = false,
files?: Array<Record<string, unknown>> | null
): Promise<DifyResponse<CompletionResponse> | DifyStream<CompletionResponse>> {
let payload: CompletionRequest;
let shouldStream = stream;
if (user === undefined && "user" in (inputOrRequest as CompletionRequest)) {
payload = inputOrRequest as CompletionRequest;
shouldStream = payload.response_mode === "streaming";
} else {
ensureNonEmptyString(user, "user");
payload = {
inputs: inputOrRequest as Record<string, unknown>,
user,
files,
response_mode: stream ? "streaming" : "blocking",
};
}
ensureNonEmptyString(payload.user, "user");
if (shouldStream) {
return this.http.requestStream<CompletionResponse>({
method: "POST",
path: "/completion-messages",
data: payload,
});
}
return this.http.request<CompletionResponse>({
method: "POST",
path: "/completion-messages",
data: payload,
});
}
stopCompletionMessage(
taskId: string,
user: string
): Promise<DifyResponse<CompletionResponse>> {
ensureNonEmptyString(taskId, "taskId");
ensureNonEmptyString(user, "user");
return this.http.request<CompletionResponse>({
method: "POST",
path: `/completion-messages/${taskId}/stop`,
data: { user },
});
}
stop(
taskId: string,
user: string
): Promise<DifyResponse<CompletionResponse>> {
return this.stopCompletionMessage(taskId, user);
}
runWorkflow(
inputs: Record<string, unknown>,
user: string,
stream = false
): Promise<DifyResponse<Record<string, unknown>> | DifyStream<Record<string, unknown>>> {
warnOnce(
"CompletionClient.runWorkflow is deprecated. Use WorkflowClient.run instead."
);
ensureNonEmptyString(user, "user");
const payload = {
inputs,
user,
response_mode: stream ? "streaming" : "blocking",
};
if (stream) {
return this.http.requestStream<Record<string, unknown>>({
method: "POST",
path: "/workflows/run",
data: payload,
});
}
return this.http.request<Record<string, unknown>>({
method: "POST",
path: "/workflows/run",
data: payload,
});
}
}
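
Likewise for `CompletionClient`: a hedged sketch with an assumed import name and placeholder inputs, showing the request-object overload and the stop helper:

```typescript
import { CompletionClient } from "dify-client"; // assumed package name

async function main() {
  const completion = new CompletionClient("app-xxx");

  // Blocking completion via the request-object overload.
  const result = await completion.createCompletionMessage({
    inputs: { prompt: "Write a haiku about autumn" },
    user: "end-user-1",
    response_mode: "blocking",
  });
  console.log(result);

  // Cancel an in-flight completion by its task id.
  await completion.stopCompletionMessage("task-id", "end-user-1");
}

main();
```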

View File

@ -0,0 +1,249 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import { KnowledgeBaseClient } from "./knowledge-base";
import { createHttpClientWithSpies } from "../../tests/test-utils";
describe("KnowledgeBaseClient", () => {
beforeEach(() => {
vi.restoreAllMocks();
});
it("handles dataset and tag operations", async () => {
const { client, request } = createHttpClientWithSpies();
const kb = new KnowledgeBaseClient(client);
await kb.listDatasets({
page: 1,
limit: 2,
keyword: "k",
includeAll: true,
tagIds: ["t1"],
});
await kb.createDataset({ name: "dataset" });
await kb.getDataset("ds");
await kb.updateDataset("ds", { name: "new" });
await kb.deleteDataset("ds");
await kb.updateDocumentStatus("ds", "enable", ["doc1"]);
await kb.listTags();
await kb.createTag({ name: "tag" });
await kb.updateTag({ tag_id: "tag", name: "name" });
await kb.deleteTag({ tag_id: "tag" });
await kb.bindTags({ tag_ids: ["tag"], target_id: "doc" });
await kb.unbindTags({ tag_id: "tag", target_id: "doc" });
await kb.getDatasetTags("ds");
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/datasets",
query: {
page: 1,
limit: 2,
keyword: "k",
include_all: true,
tag_ids: ["t1"],
},
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/datasets",
data: { name: "dataset" },
});
expect(request).toHaveBeenCalledWith({
method: "PATCH",
path: "/datasets/ds",
data: { name: "new" },
});
expect(request).toHaveBeenCalledWith({
method: "PATCH",
path: "/datasets/ds/documents/status/enable",
data: { document_ids: ["doc1"] },
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/tags/binding",
data: { tag_ids: ["tag"], target_id: "doc" },
});
});
it("handles document operations", async () => {
const { client, request } = createHttpClientWithSpies();
const kb = new KnowledgeBaseClient(client);
const form = { append: vi.fn(), getHeaders: () => ({}) };
await kb.createDocumentByText("ds", { name: "doc", text: "text" });
await kb.updateDocumentByText("ds", "doc", { name: "doc2" });
await kb.createDocumentByFile("ds", form);
await kb.updateDocumentByFile("ds", "doc", form);
await kb.listDocuments("ds", { page: 1, limit: 20, keyword: "k" });
await kb.getDocument("ds", "doc", { metadata: "all" });
await kb.deleteDocument("ds", "doc");
await kb.getDocumentIndexingStatus("ds", "batch");
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/ds/document/create_by_text",
data: { name: "doc", text: "text" },
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/ds/documents/doc/update_by_text",
data: { name: "doc2" },
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/ds/document/create_by_file",
data: form,
});
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/datasets/ds/documents",
query: { page: 1, limit: 20, keyword: "k", status: undefined },
});
});
it("handles segments and child chunks", async () => {
const { client, request } = createHttpClientWithSpies();
const kb = new KnowledgeBaseClient(client);
await kb.createSegments("ds", "doc", { segments: [{ content: "x" }] });
await kb.listSegments("ds", "doc", { page: 1, limit: 10, keyword: "k" });
await kb.getSegment("ds", "doc", "seg");
await kb.updateSegment("ds", "doc", "seg", {
segment: { content: "y" },
});
await kb.deleteSegment("ds", "doc", "seg");
await kb.createChildChunk("ds", "doc", "seg", { content: "c" });
await kb.listChildChunks("ds", "doc", "seg", { page: 1, limit: 10 });
await kb.updateChildChunk("ds", "doc", "seg", "child", {
content: "c2",
});
await kb.deleteChildChunk("ds", "doc", "seg", "child");
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/ds/documents/doc/segments",
data: { segments: [{ content: "x" }] },
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/ds/documents/doc/segments/seg",
data: { segment: { content: "y" } },
});
expect(request).toHaveBeenCalledWith({
method: "PATCH",
path: "/datasets/ds/documents/doc/segments/seg/child_chunks/child",
data: { content: "c2" },
});
});
it("handles metadata and retrieval", async () => {
const { client, request } = createHttpClientWithSpies();
const kb = new KnowledgeBaseClient(client);
await kb.listMetadata("ds");
await kb.createMetadata("ds", { name: "m", type: "string" });
await kb.updateMetadata("ds", "mid", { name: "m2" });
await kb.deleteMetadata("ds", "mid");
await kb.listBuiltInMetadata("ds");
await kb.updateBuiltInMetadata("ds", "enable");
await kb.updateDocumentsMetadata("ds", {
operation_data: [
{ document_id: "doc", metadata_list: [{ id: "m", name: "n" }] },
],
});
await kb.hitTesting("ds", { query: "q" });
await kb.retrieve("ds", { query: "q" });
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/datasets/ds/metadata",
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/ds/metadata",
data: { name: "m", type: "string" },
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/ds/hit-testing",
data: { query: "q" },
});
});
it("handles pipeline operations", async () => {
const { client, request, requestStream } = createHttpClientWithSpies();
const kb = new KnowledgeBaseClient(client);
const warn = vi.spyOn(console, "warn").mockImplementation(() => {});
const form = { append: vi.fn(), getHeaders: () => ({}) };
await kb.listDatasourcePlugins("ds", { isPublished: true });
await kb.runDatasourceNode("ds", "node", {
inputs: { input: "x" },
datasource_type: "custom",
is_published: true,
});
await kb.runPipeline("ds", {
inputs: { input: "x" },
datasource_type: "custom",
datasource_info_list: [],
start_node_id: "start",
is_published: true,
response_mode: "streaming",
});
await kb.runPipeline("ds", {
inputs: { input: "x" },
datasource_type: "custom",
datasource_info_list: [],
start_node_id: "start",
is_published: true,
response_mode: "blocking",
});
await kb.uploadPipelineFile(form);
expect(warn).toHaveBeenCalled();
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/datasets/ds/pipeline/datasource-plugins",
query: { is_published: true },
});
expect(requestStream).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/ds/pipeline/datasource/nodes/node/run",
data: {
inputs: { input: "x" },
datasource_type: "custom",
is_published: true,
},
});
expect(requestStream).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/ds/pipeline/run",
data: {
inputs: { input: "x" },
datasource_type: "custom",
datasource_info_list: [],
start_node_id: "start",
is_published: true,
response_mode: "streaming",
},
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/ds/pipeline/run",
data: {
inputs: { input: "x" },
datasource_type: "custom",
datasource_info_list: [],
start_node_id: "start",
is_published: true,
response_mode: "blocking",
},
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/datasets/pipeline/file-upload",
data: form,
});
});
});

View File

@ -0,0 +1,706 @@
import { DifyClient } from "./base";
import type {
DatasetCreateRequest,
DatasetListOptions,
DatasetTagBindingRequest,
DatasetTagCreateRequest,
DatasetTagDeleteRequest,
DatasetTagUnbindingRequest,
DatasetTagUpdateRequest,
DatasetUpdateRequest,
DocumentGetOptions,
DocumentListOptions,
DocumentStatusAction,
DocumentTextCreateRequest,
DocumentTextUpdateRequest,
SegmentCreateRequest,
SegmentListOptions,
SegmentUpdateRequest,
ChildChunkCreateRequest,
ChildChunkListOptions,
ChildChunkUpdateRequest,
MetadataCreateRequest,
MetadataOperationRequest,
MetadataUpdateRequest,
HitTestingRequest,
DatasourcePluginListOptions,
DatasourceNodeRunRequest,
PipelineRunRequest,
KnowledgeBaseResponse,
PipelineStreamEvent,
} from "../types/knowledge-base";
import type { DifyResponse, DifyStream, QueryParams } from "../types/common";
import {
ensureNonEmptyString,
ensureOptionalBoolean,
ensureOptionalInt,
ensureOptionalString,
ensureStringArray,
} from "./validation";
import { FileUploadError, ValidationError } from "../errors/dify-error";
import { isFormData } from "../http/form-data";
const warned = new Set<string>();
const warnOnce = (message: string): void => {
if (warned.has(message)) {
return;
}
warned.add(message);
console.warn(message);
};
const ensureFormData = (form: unknown, context: string): void => {
if (!isFormData(form)) {
throw new FileUploadError(`${context} requires FormData`);
}
};
const ensureNonEmptyArray = (value: unknown, name: string): void => {
if (!Array.isArray(value) || value.length === 0) {
throw new ValidationError(`${name} must be a non-empty array`);
}
};
const warnPipelineRoutes = (): void => {
warnOnce(
"RAG pipeline endpoints may be unavailable unless the service API registers dataset/rag_pipeline routes."
);
};
export class KnowledgeBaseClient extends DifyClient {
async listDatasets(
options?: DatasetListOptions
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureOptionalInt(options?.page, "page");
ensureOptionalInt(options?.limit, "limit");
ensureOptionalString(options?.keyword, "keyword");
ensureOptionalBoolean(options?.includeAll, "includeAll");
const query: QueryParams = {
page: options?.page,
limit: options?.limit,
keyword: options?.keyword ?? undefined,
include_all: options?.includeAll ?? undefined,
};
if (options?.tagIds && options.tagIds.length > 0) {
ensureStringArray(options.tagIds, "tagIds");
query.tag_ids = options.tagIds;
}
return this.http.request({
method: "GET",
path: "/datasets",
query,
});
}
async createDataset(
request: DatasetCreateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(request.name, "name");
return this.http.request({
method: "POST",
path: "/datasets",
data: request,
});
}
async getDataset(datasetId: string): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
return this.http.request({
method: "GET",
path: `/datasets/${datasetId}`,
});
}
async updateDataset(
datasetId: string,
request: DatasetUpdateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
if (request.name !== undefined && request.name !== null) {
ensureNonEmptyString(request.name, "name");
}
return this.http.request({
method: "PATCH",
path: `/datasets/${datasetId}`,
data: request,
});
}
async deleteDataset(datasetId: string): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
return this.http.request({
method: "DELETE",
path: `/datasets/${datasetId}`,
});
}
async updateDocumentStatus(
datasetId: string,
action: DocumentStatusAction,
documentIds: string[]
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(action, "action");
ensureStringArray(documentIds, "documentIds");
return this.http.request({
method: "PATCH",
path: `/datasets/${datasetId}/documents/status/${action}`,
data: {
document_ids: documentIds,
},
});
}
async listTags(): Promise<DifyResponse<KnowledgeBaseResponse>> {
return this.http.request({
method: "GET",
path: "/datasets/tags",
});
}
async createTag(
request: DatasetTagCreateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(request.name, "name");
return this.http.request({
method: "POST",
path: "/datasets/tags",
data: request,
});
}
async updateTag(
request: DatasetTagUpdateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(request.tag_id, "tag_id");
ensureNonEmptyString(request.name, "name");
return this.http.request({
method: "PATCH",
path: "/datasets/tags",
data: request,
});
}
async deleteTag(
request: DatasetTagDeleteRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(request.tag_id, "tag_id");
return this.http.request({
method: "DELETE",
path: "/datasets/tags",
data: request,
});
}
async bindTags(
request: DatasetTagBindingRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureStringArray(request.tag_ids, "tag_ids");
ensureNonEmptyString(request.target_id, "target_id");
return this.http.request({
method: "POST",
path: "/datasets/tags/binding",
data: request,
});
}
async unbindTags(
request: DatasetTagUnbindingRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(request.tag_id, "tag_id");
ensureNonEmptyString(request.target_id, "target_id");
return this.http.request({
method: "POST",
path: "/datasets/tags/unbinding",
data: request,
});
}
async getDatasetTags(
datasetId: string
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
return this.http.request({
method: "GET",
path: `/datasets/${datasetId}/tags`,
});
}
async createDocumentByText(
datasetId: string,
request: DocumentTextCreateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(request.name, "name");
ensureNonEmptyString(request.text, "text");
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/document/create_by_text`,
data: request,
});
}
async updateDocumentByText(
datasetId: string,
documentId: string,
request: DocumentTextUpdateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
if (request.name !== undefined && request.name !== null) {
ensureNonEmptyString(request.name, "name");
}
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/documents/${documentId}/update_by_text`,
data: request,
});
}
async createDocumentByFile(
datasetId: string,
form: unknown
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureFormData(form, "createDocumentByFile");
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/document/create_by_file`,
data: form,
});
}
async updateDocumentByFile(
datasetId: string,
documentId: string,
form: unknown
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
ensureFormData(form, "updateDocumentByFile");
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/documents/${documentId}/update_by_file`,
data: form,
});
}
async listDocuments(
datasetId: string,
options?: DocumentListOptions
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureOptionalInt(options?.page, "page");
ensureOptionalInt(options?.limit, "limit");
ensureOptionalString(options?.keyword, "keyword");
ensureOptionalString(options?.status, "status");
return this.http.request({
method: "GET",
path: `/datasets/${datasetId}/documents`,
query: {
page: options?.page,
limit: options?.limit,
keyword: options?.keyword ?? undefined,
status: options?.status ?? undefined,
},
});
}
async getDocument(
datasetId: string,
documentId: string,
options?: DocumentGetOptions
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
if (options?.metadata) {
const allowed = new Set(["all", "only", "without"]);
if (!allowed.has(options.metadata)) {
throw new ValidationError("metadata must be one of all, only, without");
}
}
return this.http.request({
method: "GET",
path: `/datasets/${datasetId}/documents/${documentId}`,
query: {
metadata: options?.metadata ?? undefined,
},
});
}
async deleteDocument(
datasetId: string,
documentId: string
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
return this.http.request({
method: "DELETE",
path: `/datasets/${datasetId}/documents/${documentId}`,
});
}
async getDocumentIndexingStatus(
datasetId: string,
batch: string
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(batch, "batch");
return this.http.request({
method: "GET",
path: `/datasets/${datasetId}/documents/${batch}/indexing-status`,
});
}
async createSegments(
datasetId: string,
documentId: string,
request: SegmentCreateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
ensureNonEmptyArray(request.segments, "segments");
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/documents/${documentId}/segments`,
data: request,
});
}
async listSegments(
datasetId: string,
documentId: string,
options?: SegmentListOptions
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
ensureOptionalInt(options?.page, "page");
ensureOptionalInt(options?.limit, "limit");
ensureOptionalString(options?.keyword, "keyword");
if (options?.status && options.status.length > 0) {
ensureStringArray(options.status, "status");
}
const query: QueryParams = {
page: options?.page,
limit: options?.limit,
keyword: options?.keyword ?? undefined,
};
if (options?.status && options.status.length > 0) {
query.status = options.status;
}
return this.http.request({
method: "GET",
path: `/datasets/${datasetId}/documents/${documentId}/segments`,
query,
});
}
async getSegment(
datasetId: string,
documentId: string,
segmentId: string
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
ensureNonEmptyString(segmentId, "segmentId");
return this.http.request({
method: "GET",
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}`,
});
}
async updateSegment(
datasetId: string,
documentId: string,
segmentId: string,
request: SegmentUpdateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
ensureNonEmptyString(segmentId, "segmentId");
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}`,
data: request,
});
}
async deleteSegment(
datasetId: string,
documentId: string,
segmentId: string
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
ensureNonEmptyString(segmentId, "segmentId");
return this.http.request({
method: "DELETE",
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}`,
});
}
async createChildChunk(
datasetId: string,
documentId: string,
segmentId: string,
request: ChildChunkCreateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
ensureNonEmptyString(segmentId, "segmentId");
ensureNonEmptyString(request.content, "content");
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}/child_chunks`,
data: request,
});
}
async listChildChunks(
datasetId: string,
documentId: string,
segmentId: string,
options?: ChildChunkListOptions
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
ensureNonEmptyString(segmentId, "segmentId");
ensureOptionalInt(options?.page, "page");
ensureOptionalInt(options?.limit, "limit");
ensureOptionalString(options?.keyword, "keyword");
return this.http.request({
method: "GET",
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}/child_chunks`,
query: {
page: options?.page,
limit: options?.limit,
keyword: options?.keyword ?? undefined,
},
});
}
async updateChildChunk(
datasetId: string,
documentId: string,
segmentId: string,
childChunkId: string,
request: ChildChunkUpdateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
ensureNonEmptyString(segmentId, "segmentId");
ensureNonEmptyString(childChunkId, "childChunkId");
ensureNonEmptyString(request.content, "content");
return this.http.request({
method: "PATCH",
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}/child_chunks/${childChunkId}`,
data: request,
});
}
async deleteChildChunk(
datasetId: string,
documentId: string,
segmentId: string,
childChunkId: string
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(documentId, "documentId");
ensureNonEmptyString(segmentId, "segmentId");
ensureNonEmptyString(childChunkId, "childChunkId");
return this.http.request({
method: "DELETE",
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}/child_chunks/${childChunkId}`,
});
}
async listMetadata(
datasetId: string
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
return this.http.request({
method: "GET",
path: `/datasets/${datasetId}/metadata`,
});
}
async createMetadata(
datasetId: string,
request: MetadataCreateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(request.name, "name");
ensureNonEmptyString(request.type, "type");
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/metadata`,
data: request,
});
}
async updateMetadata(
datasetId: string,
metadataId: string,
request: MetadataUpdateRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(metadataId, "metadataId");
ensureNonEmptyString(request.name, "name");
return this.http.request({
method: "PATCH",
path: `/datasets/${datasetId}/metadata/${metadataId}`,
data: request,
});
}
async deleteMetadata(
datasetId: string,
metadataId: string
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(metadataId, "metadataId");
return this.http.request({
method: "DELETE",
path: `/datasets/${datasetId}/metadata/${metadataId}`,
});
}
async listBuiltInMetadata(
datasetId: string
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
return this.http.request({
method: "GET",
path: `/datasets/${datasetId}/metadata/built-in`,
});
}
async updateBuiltInMetadata(
datasetId: string,
action: "enable" | "disable"
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(action, "action");
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/metadata/built-in/${action}`,
});
}
async updateDocumentsMetadata(
datasetId: string,
request: MetadataOperationRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyArray(request.operation_data, "operation_data");
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/documents/metadata`,
data: request,
});
}
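/**
* Run a retrieval hit-test against the dataset. `query` and
* `attachment_ids` are optional and are only validated when provided,
* so text-only, attachment-only, and combined requests are accepted.
*/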
async hitTesting(
datasetId: string,
request: HitTestingRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
if (request.query !== undefined && request.query !== null) {
ensureOptionalString(request.query, "query");
}
if (request.attachment_ids && request.attachment_ids.length > 0) {
ensureStringArray(request.attachment_ids, "attachment_ids");
}
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/hit-testing`,
data: request,
});
}
async retrieve(
datasetId: string,
request: HitTestingRequest
): Promise<DifyResponse<KnowledgeBaseResponse>> {
ensureNonEmptyString(datasetId, "datasetId");
return this.http.request({
method: "POST",
path: `/datasets/${datasetId}/retrieve`,
data: request,
});
}
async listDatasourcePlugins(
datasetId: string,
options?: DatasourcePluginListOptions
): Promise<DifyResponse<KnowledgeBaseResponse>> {
warnPipelineRoutes();
ensureNonEmptyString(datasetId, "datasetId");
ensureOptionalBoolean(options?.isPublished, "isPublished");
return this.http.request({
method: "GET",
path: `/datasets/${datasetId}/pipeline/datasource-plugins`,
query: {
is_published: options?.isPublished ?? undefined,
},
});
}
async runDatasourceNode(
datasetId: string,
nodeId: string,
request: DatasourceNodeRunRequest
): Promise<DifyStream<PipelineStreamEvent>> {
warnPipelineRoutes();
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(nodeId, "nodeId");
ensureNonEmptyString(request.datasource_type, "datasource_type");
return this.http.requestStream<PipelineStreamEvent>({
method: "POST",
path: `/datasets/${datasetId}/pipeline/datasource/nodes/${nodeId}/run`,
data: request,
});
}
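/**
* Run a knowledge pipeline. When `response_mode` is "streaming" the call
* returns a DifyStream of pipeline events; otherwise it resolves with a
* blocking DifyResponse.
*/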
async runPipeline(
datasetId: string,
request: PipelineRunRequest
): Promise<DifyResponse<KnowledgeBaseResponse> | DifyStream<PipelineStreamEvent>> {
warnPipelineRoutes();
ensureNonEmptyString(datasetId, "datasetId");
ensureNonEmptyString(request.datasource_type, "datasource_type");
ensureNonEmptyString(request.start_node_id, "start_node_id");
const shouldStream = request.response_mode === "streaming";
if (shouldStream) {
return this.http.requestStream<PipelineStreamEvent>({
method: "POST",
path: `/datasets/${datasetId}/pipeline/run`,
data: request,
});
}
return this.http.request<KnowledgeBaseResponse>({
method: "POST",
path: `/datasets/${datasetId}/pipeline/run`,
data: request,
});
}
async uploadPipelineFile(
form: unknown
): Promise<DifyResponse<KnowledgeBaseResponse>> {
warnPipelineRoutes();
ensureFormData(form, "uploadPipelineFile");
return this.http.request({
method: "POST",
path: "/datasets/pipeline/file-upload",
data: form,
});
}
}

View File

@ -0,0 +1,91 @@
import { describe, expect, it } from "vitest";
import {
ensureNonEmptyString,
ensureOptionalBoolean,
ensureOptionalInt,
ensureOptionalString,
ensureOptionalStringArray,
ensureRating,
ensureStringArray,
validateParams,
} from "./validation";
const makeLongString = (length: number) => "a".repeat(length);
describe("validation utilities", () => {
it("ensureNonEmptyString throws on empty or whitespace", () => {
expect(() => ensureNonEmptyString("", "name")).toThrow();
expect(() => ensureNonEmptyString(" ", "name")).toThrow();
});
it("ensureNonEmptyString throws on overly long strings", () => {
expect(() =>
ensureNonEmptyString(makeLongString(10001), "name")
).toThrow();
});
it("ensureOptionalString ignores undefined and validates when set", () => {
expect(() => ensureOptionalString(undefined, "opt")).not.toThrow();
expect(() => ensureOptionalString("", "opt")).toThrow();
});
it("ensureOptionalString throws on overly long strings", () => {
expect(() => ensureOptionalString(makeLongString(10001), "opt")).toThrow();
});
it("ensureOptionalInt validates integer", () => {
expect(() => ensureOptionalInt(undefined, "limit")).not.toThrow();
expect(() => ensureOptionalInt(1.2, "limit")).toThrow();
});
it("ensureOptionalBoolean validates boolean", () => {
expect(() => ensureOptionalBoolean(undefined, "flag")).not.toThrow();
expect(() => ensureOptionalBoolean("yes", "flag")).toThrow();
});
it("ensureStringArray enforces size and content", () => {
expect(() => ensureStringArray([], "items")).toThrow();
expect(() => ensureStringArray([""], "items")).toThrow();
expect(() =>
ensureStringArray(Array.from({ length: 1001 }, () => "a"), "items")
).toThrow();
expect(() => ensureStringArray(["ok"], "items")).not.toThrow();
});
it("ensureOptionalStringArray ignores undefined", () => {
expect(() => ensureOptionalStringArray(undefined, "tags")).not.toThrow();
});
it("ensureOptionalStringArray validates when set", () => {
expect(() => ensureOptionalStringArray(["valid"], "tags")).not.toThrow();
expect(() => ensureOptionalStringArray([], "tags")).toThrow();
expect(() => ensureOptionalStringArray([""], "tags")).toThrow();
});
it("ensureRating validates allowed values", () => {
expect(() => ensureRating(undefined)).not.toThrow();
expect(() => ensureRating("like")).not.toThrow();
expect(() => ensureRating("bad")).toThrow();
});
it("validateParams enforces generic rules", () => {
expect(() => validateParams({ user: 123 })).toThrow();
expect(() => validateParams({ rating: "bad" })).toThrow();
expect(() => validateParams({ page: 1.1 })).toThrow();
expect(() => validateParams({ files: "bad" })).toThrow();
// Empty strings are allowed for optional params (e.g., keyword: "" means no filter)
expect(() => validateParams({ keyword: "" })).not.toThrow();
expect(() => validateParams({ name: makeLongString(10001) })).toThrow();
expect(() =>
validateParams({ items: Array.from({ length: 1001 }, () => "a") })
).toThrow();
expect(() =>
validateParams({
data: Object.fromEntries(
Array.from({ length: 101 }, (_, i) => [String(i), i])
),
})
).toThrow();
expect(() => validateParams({ user: "u", page: 1 })).not.toThrow();
});
});

View File

@ -0,0 +1,136 @@
import { ValidationError } from "../errors/dify-error";
const MAX_STRING_LENGTH = 10000;
const MAX_LIST_LENGTH = 1000;
const MAX_DICT_LENGTH = 100;
export function ensureNonEmptyString(
value: unknown,
name: string
): asserts value is string {
if (typeof value !== "string" || value.trim().length === 0) {
throw new ValidationError(`${name} must be a non-empty string`);
}
if (value.length > MAX_STRING_LENGTH) {
throw new ValidationError(
`${name} exceeds maximum length of ${MAX_STRING_LENGTH} characters`
);
}
}
/**
* Validates optional string fields that must be non-empty when provided.
* Use this for fields like `name` that are optional but should not be empty strings.
*
* For filter parameters that accept empty strings (e.g., `keyword: ""`),
* use `validateParams` which allows empty strings for optional params.
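*
* A minimal illustration (values are hypothetical):
*
* @example
* ensureOptionalString(undefined, "name"); // skipped: undefined is allowed
* ensureOptionalString("My dataset", "name"); // passes
* ensureOptionalString("", "name"); // throws ValidationError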
*/
export function ensureOptionalString(value: unknown, name: string): void {
if (value === undefined || value === null) {
return;
}
if (typeof value !== "string" || value.trim().length === 0) {
throw new ValidationError(`${name} must be a non-empty string when set`);
}
if (value.length > MAX_STRING_LENGTH) {
throw new ValidationError(
`${name} exceeds maximum length of ${MAX_STRING_LENGTH} characters`
);
}
}
export function ensureOptionalInt(value: unknown, name: string): void {
if (value === undefined || value === null) {
return;
}
if (!Number.isInteger(value)) {
throw new ValidationError(`${name} must be an integer when set`);
}
}
export function ensureOptionalBoolean(value: unknown, name: string): void {
if (value === undefined || value === null) {
return;
}
if (typeof value !== "boolean") {
throw new ValidationError(`${name} must be a boolean when set`);
}
}
export function ensureStringArray(value: unknown, name: string): void {
if (!Array.isArray(value) || value.length === 0) {
throw new ValidationError(`${name} must be a non-empty string array`);
}
if (value.length > MAX_LIST_LENGTH) {
throw new ValidationError(
`${name} exceeds maximum size of ${MAX_LIST_LENGTH} items`
);
}
value.forEach((item) => {
if (typeof item !== "string" || item.trim().length === 0) {
throw new ValidationError(`${name} must contain non-empty strings`);
}
});
}
export function ensureOptionalStringArray(value: unknown, name: string): void {
if (value === undefined || value === null) {
return;
}
ensureStringArray(value, name);
}
export function ensureRating(value: unknown): void {
if (value === undefined || value === null) {
return;
}
if (value !== "like" && value !== "dislike") {
throw new ValidationError("rating must be either 'like' or 'dislike'");
}
}
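/**
* Generic, best-effort validation for arbitrary request params: size limits
* for strings, arrays, and plain objects, plus type checks for well-known
* keys (`user`, `page`, `limit`, `page_size`, `files`, `rating`). Empty
* strings are allowed here; required fields are checked at method level.
*
* @example
* // Hypothetical payload: passes the generic rules.
* validateParams({ user: "u-1", page: 1, keyword: "" });
* // Throws: `page` must be an integer.
* validateParams({ page: 1.5 });
*/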
export function validateParams(params: Record<string, unknown>): void {
Object.entries(params).forEach(([key, value]) => {
if (value === undefined || value === null) {
return;
}
// Only check max length for strings; empty strings are allowed for optional params
// Required fields are validated at method level via ensureNonEmptyString
if (typeof value === "string") {
if (value.length > MAX_STRING_LENGTH) {
throw new ValidationError(
`Parameter '${key}' exceeds maximum length of ${MAX_STRING_LENGTH} characters`
);
}
} else if (Array.isArray(value)) {
if (value.length > MAX_LIST_LENGTH) {
throw new ValidationError(
`Parameter '${key}' exceeds maximum size of ${MAX_LIST_LENGTH} items`
);
}
} else if (typeof value === "object") {
if (Object.keys(value as Record<string, unknown>).length > MAX_DICT_LENGTH) {
throw new ValidationError(
`Parameter '${key}' exceeds maximum size of ${MAX_DICT_LENGTH} items`
);
}
}
if (key === "user" && typeof value !== "string") {
throw new ValidationError(`Parameter '${key}' must be a string`);
}
if (
(key === "page" || key === "limit" || key === "page_size") &&
!Number.isInteger(value)
) {
throw new ValidationError(`Parameter '${key}' must be an integer`);
}
if (key === "files" && !Array.isArray(value) && typeof value !== "object") {
throw new ValidationError(`Parameter '${key}' must be a list or dict`);
}
if (key === "rating" && value !== "like" && value !== "dislike") {
throw new ValidationError(`Parameter '${key}' must be 'like' or 'dislike'`);
}
});
}

View File

@ -0,0 +1,119 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import { WorkflowClient } from "./workflow";
import { createHttpClientWithSpies } from "../../tests/test-utils";
describe("WorkflowClient", () => {
beforeEach(() => {
vi.restoreAllMocks();
});
it("runs workflows with blocking and streaming modes", async () => {
const { client, request, requestStream } = createHttpClientWithSpies();
const workflow = new WorkflowClient(client);
await workflow.run({ inputs: { input: "x" }, user: "user" });
await workflow.run({ input: "x" }, "user", true);
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/workflows/run",
data: {
inputs: { input: "x" },
user: "user",
},
});
expect(requestStream).toHaveBeenCalledWith({
method: "POST",
path: "/workflows/run",
data: {
inputs: { input: "x" },
user: "user",
response_mode: "streaming",
},
});
});
it("runs workflow by id", async () => {
const { client, request, requestStream } = createHttpClientWithSpies();
const workflow = new WorkflowClient(client);
await workflow.runById("wf", {
inputs: { input: "x" },
user: "user",
response_mode: "blocking",
});
await workflow.runById("wf", {
inputs: { input: "x" },
user: "user",
response_mode: "streaming",
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/workflows/wf/run",
data: {
inputs: { input: "x" },
user: "user",
response_mode: "blocking",
},
});
expect(requestStream).toHaveBeenCalledWith({
method: "POST",
path: "/workflows/wf/run",
data: {
inputs: { input: "x" },
user: "user",
response_mode: "streaming",
},
});
});
it("gets run details and stops workflow", async () => {
const { client, request } = createHttpClientWithSpies();
const workflow = new WorkflowClient(client);
await workflow.getRun("run");
await workflow.stop("task", "user");
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/workflows/run/run",
});
expect(request).toHaveBeenCalledWith({
method: "POST",
path: "/workflows/tasks/task/stop",
data: { user: "user" },
});
});
it("fetches workflow logs", async () => {
const { client, request } = createHttpClientWithSpies();
const workflow = new WorkflowClient(client);
// Use createdByEndUserSessionId to filter by user session (backend API parameter)
await workflow.getLogs({
keyword: "k",
status: "succeeded",
startTime: "2024-01-01",
endTime: "2024-01-02",
createdByEndUserSessionId: "session-123",
page: 1,
limit: 20,
});
expect(request).toHaveBeenCalledWith({
method: "GET",
path: "/workflows/logs",
query: {
keyword: "k",
status: "succeeded",
created_at__before: "2024-01-02",
created_at__after: "2024-01-01",
created_by_end_user_session_id: "session-123",
created_by_account: undefined,
page: 1,
limit: 20,
},
});
});
});

View File

@ -0,0 +1,165 @@
import { DifyClient } from "./base";
import type { WorkflowRunRequest, WorkflowRunResponse } from "../types/workflow";
import type { DifyResponse, DifyStream, QueryParams } from "../types/common";
import {
ensureNonEmptyString,
ensureOptionalInt,
ensureOptionalString,
} from "./validation";
export class WorkflowClient extends DifyClient {
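/**
* Run a workflow. Two call styles are supported: a single
* WorkflowRunRequest payload, or `(inputs, user, stream?)`. The streaming
* transport is used when `response_mode` is "streaming" (or `stream` is
* true); otherwise a blocking request is issued.
*/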
run(
request: WorkflowRunRequest
): Promise<DifyResponse<WorkflowRunResponse> | DifyStream<WorkflowRunResponse>>;
run(
inputs: Record<string, unknown>,
user: string,
stream?: boolean
): Promise<DifyResponse<WorkflowRunResponse> | DifyStream<WorkflowRunResponse>>;
run(
inputOrRequest: WorkflowRunRequest | Record<string, unknown>,
user?: string,
stream = false
): Promise<DifyResponse<WorkflowRunResponse> | DifyStream<WorkflowRunResponse>> {
let payload: WorkflowRunRequest;
let shouldStream = stream;
if (user === undefined && "user" in (inputOrRequest as WorkflowRunRequest)) {
payload = inputOrRequest as WorkflowRunRequest;
shouldStream = payload.response_mode === "streaming";
} else {
ensureNonEmptyString(user, "user");
payload = {
inputs: inputOrRequest as Record<string, unknown>,
user,
response_mode: stream ? "streaming" : "blocking",
};
}
ensureNonEmptyString(payload.user, "user");
if (shouldStream) {
return this.http.requestStream<WorkflowRunResponse>({
method: "POST",
path: "/workflows/run",
data: payload,
});
}
return this.http.request<WorkflowRunResponse>({
method: "POST",
path: "/workflows/run",
data: payload,
});
}
runById(
workflowId: string,
request: WorkflowRunRequest
): Promise<DifyResponse<WorkflowRunResponse> | DifyStream<WorkflowRunResponse>> {
ensureNonEmptyString(workflowId, "workflowId");
ensureNonEmptyString(request.user, "user");
if (request.response_mode === "streaming") {
return this.http.requestStream<WorkflowRunResponse>({
method: "POST",
path: `/workflows/${workflowId}/run`,
data: request,
});
}
return this.http.request<WorkflowRunResponse>({
method: "POST",
path: `/workflows/${workflowId}/run`,
data: request,
});
}
getRun(workflowRunId: string): Promise<DifyResponse<WorkflowRunResponse>> {
ensureNonEmptyString(workflowRunId, "workflowRunId");
return this.http.request({
method: "GET",
path: `/workflows/run/${workflowRunId}`,
});
}
stop(
taskId: string,
user: string
): Promise<DifyResponse<WorkflowRunResponse>> {
ensureNonEmptyString(taskId, "taskId");
ensureNonEmptyString(user, "user");
return this.http.request<WorkflowRunResponse>({
method: "POST",
path: `/workflows/tasks/${taskId}/stop`,
data: { user },
});
}
/**
* Get workflow execution logs with filtering options.
*
* Note: The backend API filters by `createdByEndUserSessionId` (end user session ID)
* or `createdByAccount` (account ID), not by a generic `user` parameter.
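*
* A minimal usage sketch (the `workflow` instance and all values are
* illustrative):
*
* @example
* const logs = await workflow.getLogs({
*   status: "succeeded",
*   createdByEndUserSessionId: "session-123",
*   page: 1,
*   limit: 20,
* });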
*/
getLogs(options?: {
keyword?: string;
status?: string;
createdAtBefore?: string;
createdAtAfter?: string;
createdByEndUserSessionId?: string;
createdByAccount?: string;
page?: number;
limit?: number;
startTime?: string;
endTime?: string;
}): Promise<DifyResponse<Record<string, unknown>>> {
if (options?.keyword) {
ensureOptionalString(options.keyword, "keyword");
}
if (options?.status) {
ensureOptionalString(options.status, "status");
}
if (options?.createdAtBefore) {
ensureOptionalString(options.createdAtBefore, "createdAtBefore");
}
if (options?.createdAtAfter) {
ensureOptionalString(options.createdAtAfter, "createdAtAfter");
}
if (options?.createdByEndUserSessionId) {
ensureOptionalString(
options.createdByEndUserSessionId,
"createdByEndUserSessionId"
);
}
if (options?.createdByAccount) {
ensureOptionalString(options.createdByAccount, "createdByAccount");
}
if (options?.startTime) {
ensureOptionalString(options.startTime, "startTime");
}
if (options?.endTime) {
ensureOptionalString(options.endTime, "endTime");
}
ensureOptionalInt(options?.page, "page");
ensureOptionalInt(options?.limit, "limit");
const createdAtAfter = options?.createdAtAfter ?? options?.startTime;
const createdAtBefore = options?.createdAtBefore ?? options?.endTime;
const query: QueryParams = {
keyword: options?.keyword,
status: options?.status,
created_at__before: createdAtBefore,
created_at__after: createdAtAfter,
created_by_end_user_session_id: options?.createdByEndUserSessionId,
created_by_account: options?.createdByAccount,
page: options?.page,
limit: options?.limit,
};
return this.http.request({
method: "GET",
path: "/workflows/logs",
query,
});
}
}

Some files were not shown because too many files have changed in this diff.