mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into main
This commit is contained in:
commit
20d134ba87
|
|
@ -1,13 +1,13 @@
|
||||||
---
|
---
|
||||||
name: Dify Frontend Testing
|
name: frontend-testing
|
||||||
description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
|
description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests.
|
||||||
---
|
---
|
||||||
|
|
||||||
# Dify Frontend Testing Skill
|
# Dify Frontend Testing Skill
|
||||||
|
|
||||||
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
|
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
|
||||||
|
|
||||||
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. When in doubt, always refer to that document as the canonical specification.
|
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
|
||||||
|
|
||||||
## When to Apply This Skill
|
## When to Apply This Skill
|
||||||
|
|
||||||
|
|
@ -15,7 +15,7 @@ Apply this skill when the user:
|
||||||
|
|
||||||
- Asks to **write tests** for a component, hook, or utility
|
- Asks to **write tests** for a component, hook, or utility
|
||||||
- Asks to **review existing tests** for completeness
|
- Asks to **review existing tests** for completeness
|
||||||
- Mentions **Jest**, **React Testing Library**, **RTL**, or **spec files**
|
- Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files**
|
||||||
- Requests **test coverage** improvement
|
- Requests **test coverage** improvement
|
||||||
- Uses `pnpm analyze-component` output as context
|
- Uses `pnpm analyze-component` output as context
|
||||||
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
|
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
|
||||||
|
|
@ -33,9 +33,9 @@ Apply this skill when the user:
|
||||||
|
|
||||||
| Tool | Version | Purpose |
|
| Tool | Version | Purpose |
|
||||||
|------|---------|---------|
|
|------|---------|---------|
|
||||||
| Jest | 29.7 | Test runner |
|
| Vitest | 4.0.16 | Test runner |
|
||||||
| React Testing Library | 16.0 | Component testing |
|
| React Testing Library | 16.0 | Component testing |
|
||||||
| happy-dom | - | Test environment |
|
| jsdom | - | Test environment |
|
||||||
| nock | 14.0 | HTTP mocking |
|
| nock | 14.0 | HTTP mocking |
|
||||||
| TypeScript | 5.x | Type safety |
|
| TypeScript | 5.x | Type safety |
|
||||||
|
|
||||||
|
|
@ -46,13 +46,13 @@ Apply this skill when the user:
|
||||||
pnpm test
|
pnpm test
|
||||||
|
|
||||||
# Watch mode
|
# Watch mode
|
||||||
pnpm test -- --watch
|
pnpm test:watch
|
||||||
|
|
||||||
# Run specific file
|
# Run specific file
|
||||||
pnpm test -- path/to/file.spec.tsx
|
pnpm test path/to/file.spec.tsx
|
||||||
|
|
||||||
# Generate coverage report
|
# Generate coverage report
|
||||||
pnpm test -- --coverage
|
pnpm test:coverage
|
||||||
|
|
||||||
# Analyze component complexity
|
# Analyze component complexity
|
||||||
pnpm analyze-component <path>
|
pnpm analyze-component <path>
|
||||||
|
|
@ -77,9 +77,9 @@ import Component from './index'
|
||||||
// import { ChildComponent } from './child-component'
|
// import { ChildComponent } from './child-component'
|
||||||
|
|
||||||
// ✅ Mock external dependencies only
|
// ✅ Mock external dependencies only
|
||||||
jest.mock('@/service/api')
|
vi.mock('@/service/api')
|
||||||
jest.mock('next/navigation', () => ({
|
vi.mock('next/navigation', () => ({
|
||||||
useRouter: () => ({ push: jest.fn() }),
|
useRouter: () => ({ push: vi.fn() }),
|
||||||
usePathname: () => '/test',
|
usePathname: () => '/test',
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
@ -88,7 +88,7 @@ let mockSharedState = false
|
||||||
|
|
||||||
describe('ComponentName', () => {
|
describe('ComponentName', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks() // ✅ Reset mocks BEFORE each test
|
vi.clearAllMocks() // ✅ Reset mocks BEFORE each test
|
||||||
mockSharedState = false // ✅ Reset shared state
|
mockSharedState = false // ✅ Reset shared state
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -117,7 +117,7 @@ describe('ComponentName', () => {
|
||||||
// User Interactions
|
// User Interactions
|
||||||
describe('User Interactions', () => {
|
describe('User Interactions', () => {
|
||||||
it('should handle click events', () => {
|
it('should handle click events', () => {
|
||||||
const handleClick = jest.fn()
|
const handleClick = vi.fn()
|
||||||
render(<Component onClick={handleClick} />)
|
render(<Component onClick={handleClick} />)
|
||||||
|
|
||||||
fireEvent.click(screen.getByRole('button'))
|
fireEvent.click(screen.getByRole('button'))
|
||||||
|
|
@ -155,7 +155,7 @@ describe('ComponentName', () => {
|
||||||
For each file:
|
For each file:
|
||||||
┌────────────────────────────────────────┐
|
┌────────────────────────────────────────┐
|
||||||
│ 1. Write test │
|
│ 1. Write test │
|
||||||
│ 2. Run: pnpm test -- <file>.spec.tsx │
|
│ 2. Run: pnpm test <file>.spec.tsx │
|
||||||
│ 3. PASS? → Mark complete, next file │
|
│ 3. PASS? → Mark complete, next file │
|
||||||
│ FAIL? → Fix first, then continue │
|
│ FAIL? → Fix first, then continue │
|
||||||
└────────────────────────────────────────┘
|
└────────────────────────────────────────┘
|
||||||
|
|
@ -178,7 +178,7 @@ Process in this order for multi-file testing:
|
||||||
- **500+ lines**: Consider splitting before testing
|
- **500+ lines**: Consider splitting before testing
|
||||||
- **Many dependencies**: Extract logic into hooks first
|
- **Many dependencies**: Extract logic into hooks first
|
||||||
|
|
||||||
> 📖 See `guides/workflow.md` for complete workflow details and todo list format.
|
> 📖 See `references/workflow.md` for complete workflow details and todo list format.
|
||||||
|
|
||||||
## Testing Strategy
|
## Testing Strategy
|
||||||
|
|
||||||
|
|
@ -289,17 +289,18 @@ For each test file generated, aim for:
|
||||||
- ✅ **>95%** branch coverage
|
- ✅ **>95%** branch coverage
|
||||||
- ✅ **>95%** line coverage
|
- ✅ **>95%** line coverage
|
||||||
|
|
||||||
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`.
|
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
|
||||||
|
|
||||||
## Detailed Guides
|
## Detailed Guides
|
||||||
|
|
||||||
For more detailed information, refer to:
|
For more detailed information, refer to:
|
||||||
|
|
||||||
- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
|
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
|
||||||
- `guides/mocking.md` - Mock patterns and best practices
|
- `references/mocking.md` - Mock patterns and best practices
|
||||||
- `guides/async-testing.md` - Async operations and API calls
|
- `references/async-testing.md` - Async operations and API calls
|
||||||
- `guides/domain-components.md` - Workflow, Dataset, Configuration testing
|
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
|
||||||
- `guides/common-patterns.md` - Frequently used testing patterns
|
- `references/common-patterns.md` - Frequently used testing patterns
|
||||||
|
- `references/checklist.md` - Test generation checklist and validation steps
|
||||||
|
|
||||||
## Authoritative References
|
## Authoritative References
|
||||||
|
|
||||||
|
|
@ -315,7 +316,7 @@ For more detailed information, refer to:
|
||||||
|
|
||||||
### Project Configuration
|
### Project Configuration
|
||||||
|
|
||||||
- `web/jest.config.ts` - Jest configuration
|
- `web/vitest.config.ts` - Vitest configuration
|
||||||
- `web/jest.setup.ts` - Test environment setup
|
- `web/vitest.setup.ts` - Test environment setup
|
||||||
- `web/testing/analyze-component.js` - Component analysis tool
|
- `web/testing/analyze-component.js` - Component analysis tool
|
||||||
- `web/__mocks__/react-i18next.ts` - Shared i18n mock (auto-loaded by Jest, no explicit mock needed; override locally only for custom translations)
|
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
|
||||||
|
|
|
||||||
|
|
@ -23,14 +23,14 @@ import userEvent from '@testing-library/user-event'
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// Mocks
|
// Mocks
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// WHY: Mocks must be hoisted to top of file (Jest requirement).
|
// WHY: Mocks must be hoisted to top of file (Vitest requirement).
|
||||||
// They run BEFORE imports, so keep them before component imports.
|
// They run BEFORE imports, so keep them before component imports.
|
||||||
|
|
||||||
// i18n (automatically mocked)
|
// i18n (automatically mocked)
|
||||||
// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest
|
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
|
||||||
// No explicit mock needed - it returns translation keys as-is
|
// No explicit mock needed - it returns translation keys as-is
|
||||||
// Override only if custom translations are required:
|
// Override only if custom translations are required:
|
||||||
// jest.mock('react-i18next', () => ({
|
// vi.mock('react-i18next', () => ({
|
||||||
// useTranslation: () => ({
|
// useTranslation: () => ({
|
||||||
// t: (key: string) => {
|
// t: (key: string) => {
|
||||||
// const customTranslations: Record<string, string> = {
|
// const customTranslations: Record<string, string> = {
|
||||||
|
|
@ -43,17 +43,17 @@ import userEvent from '@testing-library/user-event'
|
||||||
|
|
||||||
// Router (if component uses useRouter, usePathname, useSearchParams)
|
// Router (if component uses useRouter, usePathname, useSearchParams)
|
||||||
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
|
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
|
||||||
// const mockPush = jest.fn()
|
// const mockPush = vi.fn()
|
||||||
// jest.mock('next/navigation', () => ({
|
// vi.mock('next/navigation', () => ({
|
||||||
// useRouter: () => ({ push: mockPush }),
|
// useRouter: () => ({ push: mockPush }),
|
||||||
// usePathname: () => '/test-path',
|
// usePathname: () => '/test-path',
|
||||||
// }))
|
// }))
|
||||||
|
|
||||||
// API services (if component fetches data)
|
// API services (if component fetches data)
|
||||||
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
|
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
|
||||||
// jest.mock('@/service/api')
|
// vi.mock('@/service/api')
|
||||||
// import * as api from '@/service/api'
|
// import * as api from '@/service/api'
|
||||||
// const mockedApi = api as jest.Mocked<typeof api>
|
// const mockedApi = vi.mocked(api)
|
||||||
|
|
||||||
// Shared mock state (for portal/dropdown components)
|
// Shared mock state (for portal/dropdown components)
|
||||||
// WHY: Portal components like PortalToFollowElem need shared state between
|
// WHY: Portal components like PortalToFollowElem need shared state between
|
||||||
|
|
@ -98,7 +98,7 @@ describe('ComponentName', () => {
|
||||||
// - Prevents mock call history from leaking between tests
|
// - Prevents mock call history from leaking between tests
|
||||||
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
|
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
|
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
|
||||||
// mockOpenState = false
|
// mockOpenState = false
|
||||||
})
|
})
|
||||||
|
|
@ -155,7 +155,7 @@ describe('ComponentName', () => {
|
||||||
// - userEvent simulates real user behavior (focus, hover, then click)
|
// - userEvent simulates real user behavior (focus, hover, then click)
|
||||||
// - fireEvent is lower-level, doesn't trigger all browser events
|
// - fireEvent is lower-level, doesn't trigger all browser events
|
||||||
// const user = userEvent.setup()
|
// const user = userEvent.setup()
|
||||||
// const handleClick = jest.fn()
|
// const handleClick = vi.fn()
|
||||||
// render(<ComponentName onClick={handleClick} />)
|
// render(<ComponentName onClick={handleClick} />)
|
||||||
//
|
//
|
||||||
// await user.click(screen.getByRole('button'))
|
// await user.click(screen.getByRole('button'))
|
||||||
|
|
@ -165,7 +165,7 @@ describe('ComponentName', () => {
|
||||||
|
|
||||||
it('should call onChange when value changes', async () => {
|
it('should call onChange when value changes', async () => {
|
||||||
// const user = userEvent.setup()
|
// const user = userEvent.setup()
|
||||||
// const handleChange = jest.fn()
|
// const handleChange = vi.fn()
|
||||||
// render(<ComponentName onChange={handleChange} />)
|
// render(<ComponentName onChange={handleChange} />)
|
||||||
//
|
//
|
||||||
// await user.type(screen.getByRole('textbox'), 'new value')
|
// await user.type(screen.getByRole('textbox'), 'new value')
|
||||||
|
|
@ -198,7 +198,7 @@ describe('ComponentName', () => {
|
||||||
})
|
})
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// Async Operations (if component fetches data - useSWR, useQuery, fetch)
|
// Async Operations (if component fetches data - useQuery, fetch)
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// WHY: Async operations have 3 states users experience: loading, success, error
|
// WHY: Async operations have 3 states users experience: loading, success, error
|
||||||
describe('Async Operations', () => {
|
describe('Async Operations', () => {
|
||||||
|
|
@ -15,9 +15,9 @@ import { renderHook, act, waitFor } from '@testing-library/react'
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
||||||
// API services (if hook fetches data)
|
// API services (if hook fetches data)
|
||||||
// jest.mock('@/service/api')
|
// vi.mock('@/service/api')
|
||||||
// import * as api from '@/service/api'
|
// import * as api from '@/service/api'
|
||||||
// const mockedApi = api as jest.Mocked<typeof api>
|
// const mockedApi = vi.mocked(api)
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// Test Helpers
|
// Test Helpers
|
||||||
|
|
@ -38,7 +38,7 @@ import { renderHook, act, waitFor } from '@testing-library/react'
|
||||||
|
|
||||||
describe('useHookName', () => {
|
describe('useHookName', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
})
|
})
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
|
|
@ -145,7 +145,7 @@ describe('useHookName', () => {
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
describe('Side Effects', () => {
|
describe('Side Effects', () => {
|
||||||
it('should call callback when value changes', () => {
|
it('should call callback when value changes', () => {
|
||||||
// const callback = jest.fn()
|
// const callback = vi.fn()
|
||||||
// const { result } = renderHook(() => useHookName({ onChange: callback }))
|
// const { result } = renderHook(() => useHookName({ onChange: callback }))
|
||||||
//
|
//
|
||||||
// act(() => {
|
// act(() => {
|
||||||
|
|
@ -156,9 +156,9 @@ describe('useHookName', () => {
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should cleanup on unmount', () => {
|
it('should cleanup on unmount', () => {
|
||||||
// const cleanup = jest.fn()
|
// const cleanup = vi.fn()
|
||||||
// jest.spyOn(window, 'addEventListener')
|
// vi.spyOn(window, 'addEventListener')
|
||||||
// jest.spyOn(window, 'removeEventListener')
|
// vi.spyOn(window, 'removeEventListener')
|
||||||
//
|
//
|
||||||
// const { unmount } = renderHook(() => useHookName())
|
// const { unmount } = renderHook(() => useHookName())
|
||||||
//
|
//
|
||||||
|
|
@ -49,7 +49,7 @@ import userEvent from '@testing-library/user-event'
|
||||||
|
|
||||||
it('should submit form', async () => {
|
it('should submit form', async () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
const onSubmit = jest.fn()
|
const onSubmit = vi.fn()
|
||||||
|
|
||||||
render(<Form onSubmit={onSubmit} />)
|
render(<Form onSubmit={onSubmit} />)
|
||||||
|
|
||||||
|
|
@ -77,15 +77,15 @@ it('should submit form', async () => {
|
||||||
```typescript
|
```typescript
|
||||||
describe('Debounced Search', () => {
|
describe('Debounced Search', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.useFakeTimers()
|
vi.useFakeTimers()
|
||||||
})
|
})
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
jest.useRealTimers()
|
vi.useRealTimers()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should debounce search input', async () => {
|
it('should debounce search input', async () => {
|
||||||
const onSearch = jest.fn()
|
const onSearch = vi.fn()
|
||||||
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
|
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
|
||||||
|
|
||||||
// Type in the input
|
// Type in the input
|
||||||
|
|
@ -95,7 +95,7 @@ describe('Debounced Search', () => {
|
||||||
expect(onSearch).not.toHaveBeenCalled()
|
expect(onSearch).not.toHaveBeenCalled()
|
||||||
|
|
||||||
// Advance timers
|
// Advance timers
|
||||||
jest.advanceTimersByTime(300)
|
vi.advanceTimersByTime(300)
|
||||||
|
|
||||||
// Now search is called
|
// Now search is called
|
||||||
expect(onSearch).toHaveBeenCalledWith('query')
|
expect(onSearch).toHaveBeenCalledWith('query')
|
||||||
|
|
@ -107,8 +107,8 @@ describe('Debounced Search', () => {
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
it('should retry on failure', async () => {
|
it('should retry on failure', async () => {
|
||||||
jest.useFakeTimers()
|
vi.useFakeTimers()
|
||||||
const fetchData = jest.fn()
|
const fetchData = vi.fn()
|
||||||
.mockRejectedValueOnce(new Error('Network error'))
|
.mockRejectedValueOnce(new Error('Network error'))
|
||||||
.mockResolvedValueOnce({ data: 'success' })
|
.mockResolvedValueOnce({ data: 'success' })
|
||||||
|
|
||||||
|
|
@ -120,7 +120,7 @@ it('should retry on failure', async () => {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Advance timer for retry
|
// Advance timer for retry
|
||||||
jest.advanceTimersByTime(1000)
|
vi.advanceTimersByTime(1000)
|
||||||
|
|
||||||
// Second call succeeds
|
// Second call succeeds
|
||||||
await waitFor(() => {
|
await waitFor(() => {
|
||||||
|
|
@ -128,7 +128,7 @@ it('should retry on failure', async () => {
|
||||||
expect(screen.getByText('success')).toBeInTheDocument()
|
expect(screen.getByText('success')).toBeInTheDocument()
|
||||||
})
|
})
|
||||||
|
|
||||||
jest.useRealTimers()
|
vi.useRealTimers()
|
||||||
})
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -136,19 +136,19 @@ it('should retry on failure', async () => {
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
// Run all pending timers
|
// Run all pending timers
|
||||||
jest.runAllTimers()
|
vi.runAllTimers()
|
||||||
|
|
||||||
// Run only pending timers (not new ones created during execution)
|
// Run only pending timers (not new ones created during execution)
|
||||||
jest.runOnlyPendingTimers()
|
vi.runOnlyPendingTimers()
|
||||||
|
|
||||||
// Advance by specific time
|
// Advance by specific time
|
||||||
jest.advanceTimersByTime(1000)
|
vi.advanceTimersByTime(1000)
|
||||||
|
|
||||||
// Get current fake time
|
// Get current fake time
|
||||||
jest.now()
|
Date.now()
|
||||||
|
|
||||||
// Clear all timers
|
// Clear all timers
|
||||||
jest.clearAllTimers()
|
vi.clearAllTimers()
|
||||||
```
|
```
|
||||||
|
|
||||||
## API Testing Patterns
|
## API Testing Patterns
|
||||||
|
|
@ -158,7 +158,7 @@ jest.clearAllTimers()
|
||||||
```typescript
|
```typescript
|
||||||
describe('DataFetcher', () => {
|
describe('DataFetcher', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should show loading state', () => {
|
it('should show loading state', () => {
|
||||||
|
|
@ -241,7 +241,7 @@ it('should submit form and show success', async () => {
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
it('should fetch data on mount', async () => {
|
it('should fetch data on mount', async () => {
|
||||||
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
|
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
|
||||||
|
|
||||||
render(<ComponentWithEffect fetchData={fetchData} />)
|
render(<ComponentWithEffect fetchData={fetchData} />)
|
||||||
|
|
||||||
|
|
@ -255,7 +255,7 @@ it('should fetch data on mount', async () => {
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
it('should refetch when id changes', async () => {
|
it('should refetch when id changes', async () => {
|
||||||
const fetchData = jest.fn().mockResolvedValue({ data: 'test' })
|
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
|
||||||
|
|
||||||
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
|
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
|
||||||
|
|
||||||
|
|
@ -276,8 +276,8 @@ it('should refetch when id changes', async () => {
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
it('should cleanup subscription on unmount', () => {
|
it('should cleanup subscription on unmount', () => {
|
||||||
const subscribe = jest.fn()
|
const subscribe = vi.fn()
|
||||||
const unsubscribe = jest.fn()
|
const unsubscribe = vi.fn()
|
||||||
subscribe.mockReturnValue(unsubscribe)
|
subscribe.mockReturnValue(unsubscribe)
|
||||||
|
|
||||||
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
|
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
|
||||||
|
|
@ -332,14 +332,14 @@ expect(description).toBeInTheDocument()
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
// Bad - fake timers don't work well with real Promises
|
// Bad - fake timers don't work well with real Promises
|
||||||
jest.useFakeTimers()
|
vi.useFakeTimers()
|
||||||
await waitFor(() => {
|
await waitFor(() => {
|
||||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||||
}) // May timeout!
|
}) // May timeout!
|
||||||
|
|
||||||
// Good - use runAllTimers or advanceTimersByTime
|
// Good - use runAllTimers or advanceTimersByTime
|
||||||
jest.useFakeTimers()
|
vi.useFakeTimers()
|
||||||
render(<Component />)
|
render(<Component />)
|
||||||
jest.runAllTimers()
|
vi.runAllTimers()
|
||||||
expect(screen.getByText('Data')).toBeInTheDocument()
|
expect(screen.getByText('Data')).toBeInTheDocument()
|
||||||
```
|
```
|
||||||
|
|
@ -74,9 +74,9 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
|
||||||
### Mocks
|
### Mocks
|
||||||
|
|
||||||
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
|
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
|
||||||
- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`)
|
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
|
||||||
- [ ] Shared mock state reset in `beforeEach`
|
- [ ] Shared mock state reset in `beforeEach`
|
||||||
- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations
|
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
|
||||||
- [ ] Router mocks match actual Next.js API
|
- [ ] Router mocks match actual Next.js API
|
||||||
- [ ] Mocks reflect actual component conditional behavior
|
- [ ] Mocks reflect actual component conditional behavior
|
||||||
- [ ] Only mock: API services, complex context providers, third-party libs
|
- [ ] Only mock: API services, complex context providers, third-party libs
|
||||||
|
|
@ -114,15 +114,15 @@ For the current file being tested:
|
||||||
|
|
||||||
**Run these checks after EACH test file, not just at the end:**
|
**Run these checks after EACH test file, not just at the end:**
|
||||||
|
|
||||||
- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file**
|
- [ ] Run `pnpm test path/to/file.spec.tsx` - **MUST PASS before next file**
|
||||||
- [ ] Fix any failures immediately
|
- [ ] Fix any failures immediately
|
||||||
- [ ] Mark file as complete in todo list
|
- [ ] Mark file as complete in todo list
|
||||||
- [ ] Only then proceed to next file
|
- [ ] Only then proceed to next file
|
||||||
|
|
||||||
### After All Files Complete
|
### After All Files Complete
|
||||||
|
|
||||||
- [ ] Run full directory test: `pnpm test -- path/to/directory/`
|
- [ ] Run full directory test: `pnpm test path/to/directory/`
|
||||||
- [ ] Check coverage report: `pnpm test -- --coverage`
|
- [ ] Check coverage report: `pnpm test:coverage`
|
||||||
- [ ] Run `pnpm lint:fix` on all test files
|
- [ ] Run `pnpm lint:fix` on all test files
|
||||||
- [ ] Run `pnpm type-check:tsgo`
|
- [ ] Run `pnpm type-check:tsgo`
|
||||||
|
|
||||||
|
|
@ -132,10 +132,10 @@ For the current file being tested:
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
// ❌ Mock doesn't match actual behavior
|
// ❌ Mock doesn't match actual behavior
|
||||||
jest.mock('./Component', () => () => <div>Mocked</div>)
|
vi.mock('./Component', () => () => <div>Mocked</div>)
|
||||||
|
|
||||||
// ✅ Mock matches actual conditional logic
|
// ✅ Mock matches actual conditional logic
|
||||||
jest.mock('./Component', () => ({ isOpen }: any) =>
|
vi.mock('./Component', () => ({ isOpen }: any) =>
|
||||||
isOpen ? <div>Content</div> : null
|
isOpen ? <div>Content</div> : null
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
@ -145,7 +145,7 @@ jest.mock('./Component', () => ({ isOpen }: any) =>
|
||||||
```typescript
|
```typescript
|
||||||
// ❌ Shared state not reset
|
// ❌ Shared state not reset
|
||||||
let mockState = false
|
let mockState = false
|
||||||
jest.mock('./useHook', () => () => mockState)
|
vi.mock('./useHook', () => () => mockState)
|
||||||
|
|
||||||
// ✅ Reset in beforeEach
|
// ✅ Reset in beforeEach
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
|
|
@ -186,16 +186,16 @@ Always test these scenarios:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Run specific test
|
# Run specific test
|
||||||
pnpm test -- path/to/file.spec.tsx
|
pnpm test path/to/file.spec.tsx
|
||||||
|
|
||||||
# Run with coverage
|
# Run with coverage
|
||||||
pnpm test -- --coverage path/to/file.spec.tsx
|
pnpm test:coverage path/to/file.spec.tsx
|
||||||
|
|
||||||
# Watch mode
|
# Watch mode
|
||||||
pnpm test -- --watch path/to/file.spec.tsx
|
pnpm test:watch path/to/file.spec.tsx
|
||||||
|
|
||||||
# Update snapshots (use sparingly)
|
# Update snapshots (use sparingly)
|
||||||
pnpm test -- -u path/to/file.spec.tsx
|
pnpm test -u path/to/file.spec.tsx
|
||||||
|
|
||||||
# Analyze component
|
# Analyze component
|
||||||
pnpm analyze-component path/to/component.tsx
|
pnpm analyze-component path/to/component.tsx
|
||||||
|
|
@ -126,7 +126,7 @@ describe('Counter', () => {
|
||||||
describe('ControlledInput', () => {
|
describe('ControlledInput', () => {
|
||||||
it('should call onChange with new value', async () => {
|
it('should call onChange with new value', async () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
const handleChange = jest.fn()
|
const handleChange = vi.fn()
|
||||||
|
|
||||||
render(<ControlledInput value="" onChange={handleChange} />)
|
render(<ControlledInput value="" onChange={handleChange} />)
|
||||||
|
|
||||||
|
|
@ -136,7 +136,7 @@ describe('ControlledInput', () => {
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should display controlled value', () => {
|
it('should display controlled value', () => {
|
||||||
render(<ControlledInput value="controlled" onChange={jest.fn()} />)
|
render(<ControlledInput value="controlled" onChange={vi.fn()} />)
|
||||||
|
|
||||||
expect(screen.getByRole('textbox')).toHaveValue('controlled')
|
expect(screen.getByRole('textbox')).toHaveValue('controlled')
|
||||||
})
|
})
|
||||||
|
|
@ -195,7 +195,7 @@ describe('ItemList', () => {
|
||||||
|
|
||||||
it('should handle item selection', async () => {
|
it('should handle item selection', async () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
const onSelect = jest.fn()
|
const onSelect = vi.fn()
|
||||||
|
|
||||||
render(<ItemList items={items} onSelect={onSelect} />)
|
render(<ItemList items={items} onSelect={onSelect} />)
|
||||||
|
|
||||||
|
|
@ -217,20 +217,20 @@ describe('ItemList', () => {
|
||||||
```typescript
|
```typescript
|
||||||
describe('Modal', () => {
|
describe('Modal', () => {
|
||||||
it('should not render when closed', () => {
|
it('should not render when closed', () => {
|
||||||
render(<Modal isOpen={false} onClose={jest.fn()} />)
|
render(<Modal isOpen={false} onClose={vi.fn()} />)
|
||||||
|
|
||||||
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
|
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should render when open', () => {
|
it('should render when open', () => {
|
||||||
render(<Modal isOpen={true} onClose={jest.fn()} />)
|
render(<Modal isOpen={true} onClose={vi.fn()} />)
|
||||||
|
|
||||||
expect(screen.getByRole('dialog')).toBeInTheDocument()
|
expect(screen.getByRole('dialog')).toBeInTheDocument()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should call onClose when clicking overlay', async () => {
|
it('should call onClose when clicking overlay', async () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
const handleClose = jest.fn()
|
const handleClose = vi.fn()
|
||||||
|
|
||||||
render(<Modal isOpen={true} onClose={handleClose} />)
|
render(<Modal isOpen={true} onClose={handleClose} />)
|
||||||
|
|
||||||
|
|
@ -241,7 +241,7 @@ describe('Modal', () => {
|
||||||
|
|
||||||
it('should call onClose when pressing Escape', async () => {
|
it('should call onClose when pressing Escape', async () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
const handleClose = jest.fn()
|
const handleClose = vi.fn()
|
||||||
|
|
||||||
render(<Modal isOpen={true} onClose={handleClose} />)
|
render(<Modal isOpen={true} onClose={handleClose} />)
|
||||||
|
|
||||||
|
|
@ -254,7 +254,7 @@ describe('Modal', () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
|
|
||||||
render(
|
render(
|
||||||
<Modal isOpen={true} onClose={jest.fn()}>
|
<Modal isOpen={true} onClose={vi.fn()}>
|
||||||
<button>First</button>
|
<button>First</button>
|
||||||
<button>Second</button>
|
<button>Second</button>
|
||||||
</Modal>
|
</Modal>
|
||||||
|
|
@ -279,7 +279,7 @@ describe('Modal', () => {
|
||||||
describe('LoginForm', () => {
|
describe('LoginForm', () => {
|
||||||
it('should submit valid form', async () => {
|
it('should submit valid form', async () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
const onSubmit = jest.fn()
|
const onSubmit = vi.fn()
|
||||||
|
|
||||||
render(<LoginForm onSubmit={onSubmit} />)
|
render(<LoginForm onSubmit={onSubmit} />)
|
||||||
|
|
||||||
|
|
@ -296,7 +296,7 @@ describe('LoginForm', () => {
|
||||||
it('should show validation errors', async () => {
|
it('should show validation errors', async () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
|
|
||||||
render(<LoginForm onSubmit={jest.fn()} />)
|
render(<LoginForm onSubmit={vi.fn()} />)
|
||||||
|
|
||||||
// Submit empty form
|
// Submit empty form
|
||||||
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||||
|
|
@ -308,7 +308,7 @@ describe('LoginForm', () => {
|
||||||
it('should validate email format', async () => {
|
it('should validate email format', async () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
|
|
||||||
render(<LoginForm onSubmit={jest.fn()} />)
|
render(<LoginForm onSubmit={vi.fn()} />)
|
||||||
|
|
||||||
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
|
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
|
||||||
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
await user.click(screen.getByRole('button', { name: /sign in/i }))
|
||||||
|
|
@ -318,7 +318,7 @@ describe('LoginForm', () => {
|
||||||
|
|
||||||
it('should disable submit button while submitting', async () => {
|
it('should disable submit button while submitting', async () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
const onSubmit = jest.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
|
const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
|
||||||
|
|
||||||
render(<LoginForm onSubmit={onSubmit} />)
|
render(<LoginForm onSubmit={onSubmit} />)
|
||||||
|
|
||||||
|
|
@ -407,7 +407,7 @@ it('test 1', () => {
|
||||||
|
|
||||||
// Good - cleanup is automatic with RTL, but reset mocks
|
// Good - cleanup is automatic with RTL, but reset mocks
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
})
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -23,7 +23,7 @@ import NodeConfigPanel from './node-config-panel'
|
||||||
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
|
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
|
||||||
|
|
||||||
// Mock workflow context
|
// Mock workflow context
|
||||||
jest.mock('@/app/components/workflow/hooks', () => ({
|
vi.mock('@/app/components/workflow/hooks', () => ({
|
||||||
useWorkflowStore: () => mockWorkflowStore,
|
useWorkflowStore: () => mockWorkflowStore,
|
||||||
useNodesInteractions: () => mockNodesInteractions,
|
useNodesInteractions: () => mockNodesInteractions,
|
||||||
}))
|
}))
|
||||||
|
|
@ -31,21 +31,21 @@ jest.mock('@/app/components/workflow/hooks', () => ({
|
||||||
let mockWorkflowStore = {
|
let mockWorkflowStore = {
|
||||||
nodes: [],
|
nodes: [],
|
||||||
edges: [],
|
edges: [],
|
||||||
updateNode: jest.fn(),
|
updateNode: vi.fn(),
|
||||||
}
|
}
|
||||||
|
|
||||||
let mockNodesInteractions = {
|
let mockNodesInteractions = {
|
||||||
handleNodeSelect: jest.fn(),
|
handleNodeSelect: vi.fn(),
|
||||||
handleNodeDelete: jest.fn(),
|
handleNodeDelete: vi.fn(),
|
||||||
}
|
}
|
||||||
|
|
||||||
describe('NodeConfigPanel', () => {
|
describe('NodeConfigPanel', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
mockWorkflowStore = {
|
mockWorkflowStore = {
|
||||||
nodes: [],
|
nodes: [],
|
||||||
edges: [],
|
edges: [],
|
||||||
updateNode: jest.fn(),
|
updateNode: vi.fn(),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -161,23 +161,23 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||||
import userEvent from '@testing-library/user-event'
|
import userEvent from '@testing-library/user-event'
|
||||||
import DocumentUploader from './document-uploader'
|
import DocumentUploader from './document-uploader'
|
||||||
|
|
||||||
jest.mock('@/service/datasets', () => ({
|
vi.mock('@/service/datasets', () => ({
|
||||||
uploadDocument: jest.fn(),
|
uploadDocument: vi.fn(),
|
||||||
parseDocument: jest.fn(),
|
parseDocument: vi.fn(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
import * as datasetService from '@/service/datasets'
|
import * as datasetService from '@/service/datasets'
|
||||||
const mockedService = datasetService as jest.Mocked<typeof datasetService>
|
const mockedService = vi.mocked(datasetService)
|
||||||
|
|
||||||
describe('DocumentUploader', () => {
|
describe('DocumentUploader', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('File Upload', () => {
|
describe('File Upload', () => {
|
||||||
it('should accept valid file types', async () => {
|
it('should accept valid file types', async () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
const onUpload = jest.fn()
|
const onUpload = vi.fn()
|
||||||
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
|
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
|
||||||
|
|
||||||
render(<DocumentUploader onUpload={onUpload} />)
|
render(<DocumentUploader onUpload={onUpload} />)
|
||||||
|
|
@ -326,14 +326,14 @@ describe('DocumentList', () => {
|
||||||
describe('Search & Filtering', () => {
|
describe('Search & Filtering', () => {
|
||||||
it('should filter by search query', async () => {
|
it('should filter by search query', async () => {
|
||||||
const user = userEvent.setup()
|
const user = userEvent.setup()
|
||||||
jest.useFakeTimers()
|
vi.useFakeTimers()
|
||||||
|
|
||||||
render(<DocumentList datasetId="ds-1" />)
|
render(<DocumentList datasetId="ds-1" />)
|
||||||
|
|
||||||
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
|
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
|
||||||
|
|
||||||
// Debounce
|
// Debounce
|
||||||
jest.advanceTimersByTime(300)
|
vi.advanceTimersByTime(300)
|
||||||
|
|
||||||
await waitFor(() => {
|
await waitFor(() => {
|
||||||
expect(mockedService.getDocuments).toHaveBeenCalledWith(
|
expect(mockedService.getDocuments).toHaveBeenCalledWith(
|
||||||
|
|
@ -342,7 +342,7 @@ describe('DocumentList', () => {
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
jest.useRealTimers()
|
vi.useRealTimers()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
@ -367,13 +367,13 @@ import { render, screen, fireEvent, waitFor } from '@testing-library/react'
|
||||||
import userEvent from '@testing-library/user-event'
|
import userEvent from '@testing-library/user-event'
|
||||||
import AppConfigForm from './app-config-form'
|
import AppConfigForm from './app-config-form'
|
||||||
|
|
||||||
jest.mock('@/service/apps', () => ({
|
vi.mock('@/service/apps', () => ({
|
||||||
updateAppConfig: jest.fn(),
|
updateAppConfig: vi.fn(),
|
||||||
getAppConfig: jest.fn(),
|
getAppConfig: vi.fn(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
import * as appService from '@/service/apps'
|
import * as appService from '@/service/apps'
|
||||||
const mockedService = appService as jest.Mocked<typeof appService>
|
const mockedService = vi.mocked(appService)
|
||||||
|
|
||||||
describe('AppConfigForm', () => {
|
describe('AppConfigForm', () => {
|
||||||
const defaultConfig = {
|
const defaultConfig = {
|
||||||
|
|
@ -384,7 +384,7 @@ describe('AppConfigForm', () => {
|
||||||
}
|
}
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
|
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -19,8 +19,8 @@
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
// ❌ WRONG: Don't mock base components
|
// ❌ WRONG: Don't mock base components
|
||||||
jest.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
|
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
|
||||||
jest.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
|
vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
|
||||||
|
|
||||||
// ✅ CORRECT: Import and use real base components
|
// ✅ CORRECT: Import and use real base components
|
||||||
import Loading from '@/app/components/base/loading'
|
import Loading from '@/app/components/base/loading'
|
||||||
|
|
@ -41,20 +41,23 @@ Only mock these categories:
|
||||||
|
|
||||||
| Location | Purpose |
|
| Location | Purpose |
|
||||||
|----------|---------|
|
|----------|---------|
|
||||||
| `web/__mocks__/` | Reusable mocks shared across multiple test files |
|
| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
|
||||||
| Test file | Test-specific mocks, inline with `jest.mock()` |
|
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
|
||||||
|
| Test file | Test-specific mocks, inline with `vi.mock()` |
|
||||||
|
|
||||||
|
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
|
||||||
|
|
||||||
## Essential Mocks
|
## Essential Mocks
|
||||||
|
|
||||||
### 1. i18n (Auto-loaded via Shared Mock)
|
### 1. i18n (Auto-loaded via Global Mock)
|
||||||
|
|
||||||
A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest.
|
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
|
||||||
**No explicit mock needed** for most tests - it returns translation keys as-is.
|
**No explicit mock needed** for most tests - it returns translation keys as-is.
|
||||||
|
|
||||||
For tests requiring custom translations, override the mock:
|
For tests requiring custom translations, override the mock:
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
jest.mock('react-i18next', () => ({
|
vi.mock('react-i18next', () => ({
|
||||||
useTranslation: () => ({
|
useTranslation: () => ({
|
||||||
t: (key: string) => {
|
t: (key: string) => {
|
||||||
const translations: Record<string, string> = {
|
const translations: Record<string, string> = {
|
||||||
|
|
@ -69,15 +72,15 @@ jest.mock('react-i18next', () => ({
|
||||||
### 2. Next.js Router
|
### 2. Next.js Router
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
const mockPush = jest.fn()
|
const mockPush = vi.fn()
|
||||||
const mockReplace = jest.fn()
|
const mockReplace = vi.fn()
|
||||||
|
|
||||||
jest.mock('next/navigation', () => ({
|
vi.mock('next/navigation', () => ({
|
||||||
useRouter: () => ({
|
useRouter: () => ({
|
||||||
push: mockPush,
|
push: mockPush,
|
||||||
replace: mockReplace,
|
replace: mockReplace,
|
||||||
back: jest.fn(),
|
back: vi.fn(),
|
||||||
prefetch: jest.fn(),
|
prefetch: vi.fn(),
|
||||||
}),
|
}),
|
||||||
usePathname: () => '/current-path',
|
usePathname: () => '/current-path',
|
||||||
useSearchParams: () => new URLSearchParams('?key=value'),
|
useSearchParams: () => new URLSearchParams('?key=value'),
|
||||||
|
|
@ -85,7 +88,7 @@ jest.mock('next/navigation', () => ({
|
||||||
|
|
||||||
describe('Component', () => {
|
describe('Component', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should navigate on click', () => {
|
it('should navigate on click', () => {
|
||||||
|
|
@ -102,7 +105,7 @@ describe('Component', () => {
|
||||||
// ⚠️ Important: Use shared state for components that depend on each other
|
// ⚠️ Important: Use shared state for components that depend on each other
|
||||||
let mockPortalOpenState = false
|
let mockPortalOpenState = false
|
||||||
|
|
||||||
jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||||
PortalToFollowElem: ({ children, open, ...props }: any) => {
|
PortalToFollowElem: ({ children, open, ...props }: any) => {
|
||||||
mockPortalOpenState = open || false // Update shared state
|
mockPortalOpenState = open || false // Update shared state
|
||||||
return <div data-testid="portal" data-open={open}>{children}</div>
|
return <div data-testid="portal" data-open={open}>{children}</div>
|
||||||
|
|
@ -119,7 +122,7 @@ jest.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||||
|
|
||||||
describe('Component', () => {
|
describe('Component', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
mockPortalOpenState = false // ✅ Reset shared state
|
mockPortalOpenState = false // ✅ Reset shared state
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
@ -130,13 +133,13 @@ describe('Component', () => {
|
||||||
```typescript
|
```typescript
|
||||||
import * as api from '@/service/api'
|
import * as api from '@/service/api'
|
||||||
|
|
||||||
jest.mock('@/service/api')
|
vi.mock('@/service/api')
|
||||||
|
|
||||||
const mockedApi = api as jest.Mocked<typeof api>
|
const mockedApi = vi.mocked(api)
|
||||||
|
|
||||||
describe('Component', () => {
|
describe('Component', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
|
|
||||||
// Setup default mock implementation
|
// Setup default mock implementation
|
||||||
mockedApi.fetchData.mockResolvedValue({ data: [] })
|
mockedApi.fetchData.mockResolvedValue({ data: [] })
|
||||||
|
|
@ -239,32 +242,9 @@ describe('Component with Context', () => {
|
||||||
})
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
### 7. SWR / React Query
|
### 7. React Query
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
// SWR
|
|
||||||
jest.mock('swr', () => ({
|
|
||||||
__esModule: true,
|
|
||||||
default: jest.fn(),
|
|
||||||
}))
|
|
||||||
|
|
||||||
import useSWR from 'swr'
|
|
||||||
const mockedUseSWR = useSWR as jest.Mock
|
|
||||||
|
|
||||||
describe('Component with SWR', () => {
|
|
||||||
it('should show loading state', () => {
|
|
||||||
mockedUseSWR.mockReturnValue({
|
|
||||||
data: undefined,
|
|
||||||
error: undefined,
|
|
||||||
isLoading: true,
|
|
||||||
})
|
|
||||||
|
|
||||||
render(<Component />)
|
|
||||||
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
// React Query
|
|
||||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||||
|
|
||||||
const createTestQueryClient = () => new QueryClient({
|
const createTestQueryClient = () => new QueryClient({
|
||||||
|
|
@ -35,7 +35,7 @@ When testing a **single component, hook, or utility**:
|
||||||
2. Run `pnpm analyze-component <path>` (if available)
|
2. Run `pnpm analyze-component <path>` (if available)
|
||||||
3. Check complexity score and features detected
|
3. Check complexity score and features detected
|
||||||
4. Write the test file
|
4. Write the test file
|
||||||
5. Run test: `pnpm test -- <file>.spec.tsx`
|
5. Run test: `pnpm test <file>.spec.tsx`
|
||||||
6. Fix any failures
|
6. Fix any failures
|
||||||
7. Verify coverage meets goals (100% function, >95% branch)
|
7. Verify coverage meets goals (100% function, >95% branch)
|
||||||
```
|
```
|
||||||
|
|
@ -80,7 +80,7 @@ Process files in this recommended order:
|
||||||
```
|
```
|
||||||
┌─────────────────────────────────────────────┐
|
┌─────────────────────────────────────────────┐
|
||||||
│ 1. Write test file │
|
│ 1. Write test file │
|
||||||
│ 2. Run: pnpm test -- <file>.spec.tsx │
|
│ 2. Run: pnpm test <file>.spec.tsx │
|
||||||
│ 3. If FAIL → Fix immediately, re-run │
|
│ 3. If FAIL → Fix immediately, re-run │
|
||||||
│ 4. If PASS → Mark complete in todo list │
|
│ 4. If PASS → Mark complete in todo list │
|
||||||
│ 5. ONLY THEN proceed to next file │
|
│ 5. ONLY THEN proceed to next file │
|
||||||
|
|
@ -95,10 +95,10 @@ After all individual tests pass:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Run all tests in the directory together
|
# Run all tests in the directory together
|
||||||
pnpm test -- path/to/directory/
|
pnpm test path/to/directory/
|
||||||
|
|
||||||
# Check coverage
|
# Check coverage
|
||||||
pnpm test -- --coverage path/to/directory/
|
pnpm test:coverage path/to/directory/
|
||||||
```
|
```
|
||||||
|
|
||||||
## Component Complexity Guidelines
|
## Component Complexity Guidelines
|
||||||
|
|
@ -201,9 +201,9 @@ Run pnpm test ← Multiple failures, hard to debug
|
||||||
```
|
```
|
||||||
# GOOD: Incremental with verification
|
# GOOD: Incremental with verification
|
||||||
Write component-a.spec.tsx
|
Write component-a.spec.tsx
|
||||||
Run pnpm test -- component-a.spec.tsx ✅
|
Run pnpm test component-a.spec.tsx ✅
|
||||||
Write component-b.spec.tsx
|
Write component-b.spec.tsx
|
||||||
Run pnpm test -- component-b.spec.tsx ✅
|
Run pnpm test component-b.spec.tsx ✅
|
||||||
...continue...
|
...continue...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
../.claude/skills
|
||||||
|
|
@ -6,6 +6,9 @@
|
||||||
"context": "..",
|
"context": "..",
|
||||||
"dockerfile": "Dockerfile"
|
"dockerfile": "Dockerfile"
|
||||||
},
|
},
|
||||||
|
"mounts": [
|
||||||
|
"source=dify-dev-tmp,target=/tmp,type=volume"
|
||||||
|
],
|
||||||
"features": {
|
"features": {
|
||||||
"ghcr.io/devcontainers/features/node:1": {
|
"ghcr.io/devcontainers/features/node:1": {
|
||||||
"nodeGypDependencies": true,
|
"nodeGypDependencies": true,
|
||||||
|
|
@ -34,19 +37,13 @@
|
||||||
},
|
},
|
||||||
"postStartCommand": "./.devcontainer/post_start_command.sh",
|
"postStartCommand": "./.devcontainer/post_start_command.sh",
|
||||||
"postCreateCommand": "./.devcontainer/post_create_command.sh"
|
"postCreateCommand": "./.devcontainer/post_create_command.sh"
|
||||||
|
|
||||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||||
// "features": {},
|
// "features": {},
|
||||||
|
|
||||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||||
// "forwardPorts": [],
|
// "forwardPorts": [],
|
||||||
|
|
||||||
// Use 'postCreateCommand' to run commands after the container is created.
|
// Use 'postCreateCommand' to run commands after the container is created.
|
||||||
// "postCreateCommand": "python --version",
|
// "postCreateCommand": "python --version",
|
||||||
|
|
||||||
// Configure tool-specific properties.
|
// Configure tool-specific properties.
|
||||||
// "customizations": {},
|
// "customizations": {},
|
||||||
|
|
||||||
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
|
||||||
// "remoteUser": "root"
|
|
||||||
}
|
}
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
WORKSPACE_ROOT=$(pwd)
|
WORKSPACE_ROOT=$(pwd)
|
||||||
|
|
||||||
|
export COREPACK_ENABLE_DOWNLOAD_PROMPT=0
|
||||||
corepack enable
|
corepack enable
|
||||||
cd web && pnpm install
|
cd web && pnpm install
|
||||||
pipx install uv
|
pipx install uv
|
||||||
|
|
|
||||||
|
|
@ -7,234 +7,243 @@
|
||||||
* @crazywoola @laipz8200 @Yeuoly
|
* @crazywoola @laipz8200 @Yeuoly
|
||||||
|
|
||||||
# CODEOWNERS file
|
# CODEOWNERS file
|
||||||
.github/CODEOWNERS @laipz8200 @crazywoola
|
/.github/CODEOWNERS @laipz8200 @crazywoola
|
||||||
|
|
||||||
# Docs
|
# Docs
|
||||||
docs/ @crazywoola
|
/docs/ @crazywoola
|
||||||
|
|
||||||
# Backend (default owner, more specific rules below will override)
|
# Backend (default owner, more specific rules below will override)
|
||||||
api/ @QuantumGhost
|
/api/ @QuantumGhost
|
||||||
|
|
||||||
# Backend - MCP
|
# Backend - MCP
|
||||||
api/core/mcp/ @Nov1c444
|
/api/core/mcp/ @Nov1c444
|
||||||
api/core/entities/mcp_provider.py @Nov1c444
|
/api/core/entities/mcp_provider.py @Nov1c444
|
||||||
api/services/tools/mcp_tools_manage_service.py @Nov1c444
|
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
|
||||||
api/controllers/mcp/ @Nov1c444
|
/api/controllers/mcp/ @Nov1c444
|
||||||
api/controllers/console/app/mcp_server.py @Nov1c444
|
/api/controllers/console/app/mcp_server.py @Nov1c444
|
||||||
api/tests/**/*mcp* @Nov1c444
|
/api/tests/**/*mcp* @Nov1c444
|
||||||
|
|
||||||
# Backend - Workflow - Engine (Core graph execution engine)
|
# Backend - Workflow - Engine (Core graph execution engine)
|
||||||
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
/api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
||||||
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
/api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
||||||
api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
/api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
||||||
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
/api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
||||||
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
/api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
||||||
api/core/model_runtime/ @laipz8200 @QuantumGhost
|
/api/core/model_runtime/ @laipz8200 @QuantumGhost
|
||||||
|
|
||||||
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
|
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
|
||||||
api/core/workflow/nodes/agent/ @Nov1c444
|
/api/core/workflow/nodes/agent/ @Nov1c444
|
||||||
api/core/workflow/nodes/iteration/ @Nov1c444
|
/api/core/workflow/nodes/iteration/ @Nov1c444
|
||||||
api/core/workflow/nodes/loop/ @Nov1c444
|
/api/core/workflow/nodes/loop/ @Nov1c444
|
||||||
api/core/workflow/nodes/llm/ @Nov1c444
|
/api/core/workflow/nodes/llm/ @Nov1c444
|
||||||
|
|
||||||
# Backend - RAG (Retrieval Augmented Generation)
|
# Backend - RAG (Retrieval Augmented Generation)
|
||||||
api/core/rag/ @JohnJyong
|
/api/core/rag/ @JohnJyong
|
||||||
api/services/rag_pipeline/ @JohnJyong
|
/api/services/rag_pipeline/ @JohnJyong
|
||||||
api/services/dataset_service.py @JohnJyong
|
/api/services/dataset_service.py @JohnJyong
|
||||||
api/services/knowledge_service.py @JohnJyong
|
/api/services/knowledge_service.py @JohnJyong
|
||||||
api/services/external_knowledge_service.py @JohnJyong
|
/api/services/external_knowledge_service.py @JohnJyong
|
||||||
api/services/hit_testing_service.py @JohnJyong
|
/api/services/hit_testing_service.py @JohnJyong
|
||||||
api/services/metadata_service.py @JohnJyong
|
/api/services/metadata_service.py @JohnJyong
|
||||||
api/services/vector_service.py @JohnJyong
|
/api/services/vector_service.py @JohnJyong
|
||||||
api/services/entities/knowledge_entities/ @JohnJyong
|
/api/services/entities/knowledge_entities/ @JohnJyong
|
||||||
api/services/entities/external_knowledge_entities/ @JohnJyong
|
/api/services/entities/external_knowledge_entities/ @JohnJyong
|
||||||
api/controllers/console/datasets/ @JohnJyong
|
/api/controllers/console/datasets/ @JohnJyong
|
||||||
api/controllers/service_api/dataset/ @JohnJyong
|
/api/controllers/service_api/dataset/ @JohnJyong
|
||||||
api/models/dataset.py @JohnJyong
|
/api/models/dataset.py @JohnJyong
|
||||||
api/tasks/rag_pipeline/ @JohnJyong
|
/api/tasks/rag_pipeline/ @JohnJyong
|
||||||
api/tasks/add_document_to_index_task.py @JohnJyong
|
/api/tasks/add_document_to_index_task.py @JohnJyong
|
||||||
api/tasks/batch_clean_document_task.py @JohnJyong
|
/api/tasks/batch_clean_document_task.py @JohnJyong
|
||||||
api/tasks/clean_document_task.py @JohnJyong
|
/api/tasks/clean_document_task.py @JohnJyong
|
||||||
api/tasks/clean_notion_document_task.py @JohnJyong
|
/api/tasks/clean_notion_document_task.py @JohnJyong
|
||||||
api/tasks/document_indexing_task.py @JohnJyong
|
/api/tasks/document_indexing_task.py @JohnJyong
|
||||||
api/tasks/document_indexing_sync_task.py @JohnJyong
|
/api/tasks/document_indexing_sync_task.py @JohnJyong
|
||||||
api/tasks/document_indexing_update_task.py @JohnJyong
|
/api/tasks/document_indexing_update_task.py @JohnJyong
|
||||||
api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
/api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
||||||
api/tasks/recover_document_indexing_task.py @JohnJyong
|
/api/tasks/recover_document_indexing_task.py @JohnJyong
|
||||||
api/tasks/remove_document_from_index_task.py @JohnJyong
|
/api/tasks/remove_document_from_index_task.py @JohnJyong
|
||||||
api/tasks/retry_document_indexing_task.py @JohnJyong
|
/api/tasks/retry_document_indexing_task.py @JohnJyong
|
||||||
api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
/api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
||||||
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
/api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
||||||
api/tasks/create_segment_to_index_task.py @JohnJyong
|
/api/tasks/create_segment_to_index_task.py @JohnJyong
|
||||||
api/tasks/delete_segment_from_index_task.py @JohnJyong
|
/api/tasks/delete_segment_from_index_task.py @JohnJyong
|
||||||
api/tasks/disable_segment_from_index_task.py @JohnJyong
|
/api/tasks/disable_segment_from_index_task.py @JohnJyong
|
||||||
api/tasks/disable_segments_from_index_task.py @JohnJyong
|
/api/tasks/disable_segments_from_index_task.py @JohnJyong
|
||||||
api/tasks/enable_segment_to_index_task.py @JohnJyong
|
/api/tasks/enable_segment_to_index_task.py @JohnJyong
|
||||||
api/tasks/enable_segments_to_index_task.py @JohnJyong
|
/api/tasks/enable_segments_to_index_task.py @JohnJyong
|
||||||
api/tasks/clean_dataset_task.py @JohnJyong
|
/api/tasks/clean_dataset_task.py @JohnJyong
|
||||||
api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
/api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
||||||
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
/api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||||
|
|
||||||
# Backend - Plugins
|
# Backend - Plugins
|
||||||
api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
/api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||||
api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
/api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||||
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||||
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||||
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||||
|
|
||||||
# Backend - Trigger/Schedule/Webhook
|
# Backend - Trigger/Schedule/Webhook
|
||||||
api/controllers/trigger/ @Mairuis @Yeuoly
|
/api/controllers/trigger/ @Mairuis @Yeuoly
|
||||||
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||||
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||||
api/core/trigger/ @Mairuis @Yeuoly
|
/api/core/trigger/ @Mairuis @Yeuoly
|
||||||
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||||
api/services/trigger/ @Mairuis @Yeuoly
|
/api/services/trigger/ @Mairuis @Yeuoly
|
||||||
api/models/trigger.py @Mairuis @Yeuoly
|
/api/models/trigger.py @Mairuis @Yeuoly
|
||||||
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||||
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||||
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||||
api/libs/schedule_utils.py @Mairuis @Yeuoly
|
/api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||||
api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
/api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||||
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||||
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||||
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||||
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||||
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||||
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||||
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||||
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||||
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||||
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||||
|
|
||||||
# Backend - Async Workflow
|
# Backend - Async Workflow
|
||||||
api/services/async_workflow_service.py @Mairuis @Yeuoly
|
/api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||||
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||||
|
|
||||||
# Backend - Billing
|
# Backend - Billing
|
||||||
api/services/billing_service.py @hj24 @zyssyz123
|
/api/services/billing_service.py @hj24 @zyssyz123
|
||||||
api/controllers/console/billing/ @hj24 @zyssyz123
|
/api/controllers/console/billing/ @hj24 @zyssyz123
|
||||||
|
|
||||||
# Backend - Enterprise
|
# Backend - Enterprise
|
||||||
api/configs/enterprise/ @GarfieldDai @GareArc
|
/api/configs/enterprise/ @GarfieldDai @GareArc
|
||||||
api/services/enterprise/ @GarfieldDai @GareArc
|
/api/services/enterprise/ @GarfieldDai @GareArc
|
||||||
api/services/feature_service.py @GarfieldDai @GareArc
|
/api/services/feature_service.py @GarfieldDai @GareArc
|
||||||
api/controllers/console/feature.py @GarfieldDai @GareArc
|
/api/controllers/console/feature.py @GarfieldDai @GareArc
|
||||||
api/controllers/web/feature.py @GarfieldDai @GareArc
|
/api/controllers/web/feature.py @GarfieldDai @GareArc
|
||||||
|
|
||||||
# Backend - Database Migrations
|
# Backend - Database Migrations
|
||||||
api/migrations/ @snakevash @laipz8200 @MRZHUH
|
/api/migrations/ @snakevash @laipz8200 @MRZHUH
|
||||||
|
|
||||||
|
# Backend - Vector DB Middleware
|
||||||
|
/api/configs/middleware/vdb/* @JohnJyong
|
||||||
|
|
||||||
# Frontend
|
# Frontend
|
||||||
web/ @iamjoel
|
/web/ @iamjoel
|
||||||
|
|
||||||
|
# Frontend - Web Tests
|
||||||
|
/.github/workflows/web-tests.yml @iamjoel
|
||||||
|
|
||||||
# Frontend - App - Orchestration
|
# Frontend - App - Orchestration
|
||||||
web/app/components/workflow/ @iamjoel @zxhlyh
|
/web/app/components/workflow/ @iamjoel @zxhlyh
|
||||||
web/app/components/workflow-app/ @iamjoel @zxhlyh
|
/web/app/components/workflow-app/ @iamjoel @zxhlyh
|
||||||
web/app/components/app/configuration/ @iamjoel @zxhlyh
|
/web/app/components/app/configuration/ @iamjoel @zxhlyh
|
||||||
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
/web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - WebApp - Chat
|
# Frontend - WebApp - Chat
|
||||||
web/app/components/base/chat/ @iamjoel @zxhlyh
|
/web/app/components/base/chat/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - WebApp - Completion
|
# Frontend - WebApp - Completion
|
||||||
web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
/web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - App - List and Creation
|
# Frontend - App - List and Creation
|
||||||
web/app/components/apps/ @JzoNgKVO @iamjoel
|
/web/app/components/apps/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
/web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
/web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
/web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - App - API Documentation
|
# Frontend - App - API Documentation
|
||||||
web/app/components/develop/ @JzoNgKVO @iamjoel
|
/web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - App - Logs and Annotations
|
# Frontend - App - Logs and Annotations
|
||||||
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/log/ @JzoNgKVO @iamjoel
|
/web/app/components/app/log/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
/web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
/web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - App - Monitoring
|
# Frontend - App - Monitoring
|
||||||
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
/web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
||||||
web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
/web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - App - Settings
|
# Frontend - App - Settings
|
||||||
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
/web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - RAG - Hit Testing
|
# Frontend - RAG - Hit Testing
|
||||||
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
/web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
||||||
|
|
||||||
# Frontend - RAG - List and Creation
|
# Frontend - RAG - List and Creation
|
||||||
web/app/components/datasets/list/ @iamjoel @WTW0313
|
/web/app/components/datasets/list/ @iamjoel @WTW0313
|
||||||
web/app/components/datasets/create/ @iamjoel @WTW0313
|
/web/app/components/datasets/create/ @iamjoel @WTW0313
|
||||||
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
/web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
||||||
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
/web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
||||||
|
|
||||||
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
|
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
|
||||||
web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
/web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
||||||
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
/web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
||||||
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
/web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - RAG - Documents List
|
# Frontend - RAG - Documents List
|
||||||
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
/web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
||||||
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
/web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
||||||
|
|
||||||
# Frontend - RAG - Segments List
|
# Frontend - RAG - Segments List
|
||||||
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
/web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
||||||
|
|
||||||
# Frontend - RAG - Settings
|
# Frontend - RAG - Settings
|
||||||
web/app/components/datasets/settings/ @iamjoel @WTW0313
|
/web/app/components/datasets/settings/ @iamjoel @WTW0313
|
||||||
|
|
||||||
# Frontend - Ecosystem - Plugins
|
# Frontend - Ecosystem - Plugins
|
||||||
web/app/components/plugins/ @iamjoel @zhsama
|
/web/app/components/plugins/ @iamjoel @zhsama
|
||||||
|
|
||||||
# Frontend - Ecosystem - Tools
|
# Frontend - Ecosystem - Tools
|
||||||
web/app/components/tools/ @iamjoel @Yessenia-d
|
/web/app/components/tools/ @iamjoel @Yessenia-d
|
||||||
|
|
||||||
# Frontend - Ecosystem - MarketPlace
|
# Frontend - Ecosystem - MarketPlace
|
||||||
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
/web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
||||||
|
|
||||||
# Frontend - Login and Registration
|
# Frontend - Login and Registration
|
||||||
web/app/signin/ @douxc @iamjoel
|
/web/app/signin/ @douxc @iamjoel
|
||||||
web/app/signup/ @douxc @iamjoel
|
/web/app/signup/ @douxc @iamjoel
|
||||||
web/app/reset-password/ @douxc @iamjoel
|
/web/app/reset-password/ @douxc @iamjoel
|
||||||
web/app/install/ @douxc @iamjoel
|
/web/app/install/ @douxc @iamjoel
|
||||||
web/app/init/ @douxc @iamjoel
|
/web/app/init/ @douxc @iamjoel
|
||||||
web/app/forgot-password/ @douxc @iamjoel
|
/web/app/forgot-password/ @douxc @iamjoel
|
||||||
web/app/account/ @douxc @iamjoel
|
/web/app/account/ @douxc @iamjoel
|
||||||
|
|
||||||
# Frontend - Service Authentication
|
# Frontend - Service Authentication
|
||||||
web/service/base.ts @douxc @iamjoel
|
/web/service/base.ts @douxc @iamjoel
|
||||||
|
|
||||||
# Frontend - WebApp Authentication and Access Control
|
# Frontend - WebApp Authentication and Access Control
|
||||||
web/app/(shareLayout)/components/ @douxc @iamjoel
|
/web/app/(shareLayout)/components/ @douxc @iamjoel
|
||||||
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
/web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
||||||
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
/web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
||||||
web/app/components/app/app-access-control/ @douxc @iamjoel
|
/web/app/components/app/app-access-control/ @douxc @iamjoel
|
||||||
|
|
||||||
# Frontend - Explore Page
|
# Frontend - Explore Page
|
||||||
web/app/components/explore/ @CodingOnStar @iamjoel
|
/web/app/components/explore/ @CodingOnStar @iamjoel
|
||||||
|
|
||||||
# Frontend - Personal Settings
|
# Frontend - Personal Settings
|
||||||
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
/web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
||||||
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
/web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
||||||
|
|
||||||
# Frontend - Analytics
|
# Frontend - Analytics
|
||||||
web/app/components/base/ga/ @CodingOnStar @iamjoel
|
/web/app/components/base/ga/ @CodingOnStar @iamjoel
|
||||||
|
|
||||||
# Frontend - Base Components
|
# Frontend - Base Components
|
||||||
web/app/components/base/ @iamjoel @zxhlyh
|
/web/app/components/base/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - Utils and Hooks
|
# Frontend - Utils and Hooks
|
||||||
web/utils/classnames.ts @iamjoel @zxhlyh
|
/web/utils/classnames.ts @iamjoel @zxhlyh
|
||||||
web/utils/time.ts @iamjoel @zxhlyh
|
/web/utils/time.ts @iamjoel @zxhlyh
|
||||||
web/utils/format.ts @iamjoel @zxhlyh
|
/web/utils/format.ts @iamjoel @zxhlyh
|
||||||
web/utils/clipboard.ts @iamjoel @zxhlyh
|
/web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||||
web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
/web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - Billing and Education
|
# Frontend - Billing and Education
|
||||||
web/app/components/billing/ @iamjoel @zxhlyh
|
/web/app/components/billing/ @iamjoel @zxhlyh
|
||||||
web/app/education-apply/ @iamjoel @zxhlyh
|
/web/app/education-apply/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
# Frontend - Workspace
|
# Frontend - Workspace
|
||||||
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
/web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
||||||
|
|
||||||
|
# Docker
|
||||||
|
/docker/* @laipz8200
|
||||||
|
|
|
||||||
|
|
@ -13,12 +13,28 @@ jobs:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Check Docker Compose inputs
|
||||||
|
id: docker-compose-changes
|
||||||
|
uses: tj-actions/changed-files@v46
|
||||||
|
with:
|
||||||
|
files: |
|
||||||
|
docker/generate_docker_compose
|
||||||
|
docker/.env.example
|
||||||
|
docker/docker-compose-template.yaml
|
||||||
|
docker/docker-compose.yaml
|
||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
- uses: astral-sh/setup-uv@v6
|
- uses: astral-sh/setup-uv@v6
|
||||||
|
|
||||||
|
- name: Generate Docker Compose
|
||||||
|
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||||
|
run: |
|
||||||
|
cd docker
|
||||||
|
./generate_docker_compose
|
||||||
|
|
||||||
- run: |
|
- run: |
|
||||||
cd api
|
cd api
|
||||||
uv sync --dev
|
uv sync --dev
|
||||||
|
|
@ -66,27 +82,6 @@ jobs:
|
||||||
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||||
- name: mdformat
|
- name: mdformat
|
||||||
run: |
|
run: |
|
||||||
uvx --python 3.13 mdformat . --exclude ".claude/skills/**"
|
uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
|
||||||
|
|
||||||
- name: Install pnpm
|
|
||||||
uses: pnpm/action-setup@v4
|
|
||||||
with:
|
|
||||||
package_json_file: web/package.json
|
|
||||||
run_install: false
|
|
||||||
|
|
||||||
- name: Setup NodeJS
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: 22
|
|
||||||
cache: pnpm
|
|
||||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
|
||||||
|
|
||||||
- name: Web dependencies
|
|
||||||
working-directory: ./web
|
|
||||||
run: pnpm install --frozen-lockfile
|
|
||||||
|
|
||||||
- name: oxlint
|
|
||||||
working-directory: ./web
|
|
||||||
run: pnpm exec oxlint --config .oxlintrc.json --fix .
|
|
||||||
|
|
||||||
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
||||||
|
|
|
||||||
|
|
@ -108,36 +108,6 @@ jobs:
|
||||||
working-directory: ./web
|
working-directory: ./web
|
||||||
run: pnpm run type-check:tsgo
|
run: pnpm run type-check:tsgo
|
||||||
|
|
||||||
docker-compose-template:
|
|
||||||
name: Docker Compose Template
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Check changed files
|
|
||||||
id: changed-files
|
|
||||||
uses: tj-actions/changed-files@v46
|
|
||||||
with:
|
|
||||||
files: |
|
|
||||||
docker/generate_docker_compose
|
|
||||||
docker/.env.example
|
|
||||||
docker/docker-compose-template.yaml
|
|
||||||
docker/docker-compose.yaml
|
|
||||||
|
|
||||||
- name: Generate Docker Compose
|
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
|
||||||
run: |
|
|
||||||
cd docker
|
|
||||||
./generate_docker_compose
|
|
||||||
|
|
||||||
- name: Check for changes
|
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
|
||||||
run: git diff --exit-code
|
|
||||||
|
|
||||||
superlinter:
|
superlinter:
|
||||||
name: SuperLinter
|
name: SuperLinter
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
name: Check i18n Files and Create PR
|
name: Translate i18n Files Based on English
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
|
@ -67,25 +67,19 @@ jobs:
|
||||||
working-directory: ./web
|
working-directory: ./web
|
||||||
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
|
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
|
||||||
|
|
||||||
- name: Generate i18n type definitions
|
|
||||||
if: env.FILES_CHANGED == 'true'
|
|
||||||
working-directory: ./web
|
|
||||||
run: pnpm run gen:i18n-types
|
|
||||||
|
|
||||||
- name: Create Pull Request
|
- name: Create Pull Request
|
||||||
if: env.FILES_CHANGED == 'true'
|
if: env.FILES_CHANGED == 'true'
|
||||||
uses: peter-evans/create-pull-request@v6
|
uses: peter-evans/create-pull-request@v6
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
commit-message: 'chore(i18n): update translations based on en-US changes'
|
commit-message: 'chore(i18n): update translations based on en-US changes'
|
||||||
title: 'chore(i18n): translate i18n files and update type definitions'
|
title: 'chore(i18n): translate i18n files based on en-US changes'
|
||||||
body: |
|
body: |
|
||||||
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
|
This PR was automatically created to update i18n translation files based on changes in en-US locale.
|
||||||
|
|
||||||
**Triggered by:** ${{ github.sha }}
|
**Triggered by:** ${{ github.sha }}
|
||||||
|
|
||||||
**Changes included:**
|
**Changes included:**
|
||||||
- Updated translation files for all locales
|
- Updated translation files for all locales
|
||||||
- Regenerated TypeScript type definitions for type safety
|
|
||||||
branch: chore/automated-i18n-updates-${{ github.sha }}
|
branch: chore/automated-i18n-updates-${{ github.sha }}
|
||||||
delete-branch: true
|
delete-branch: true
|
||||||
|
|
|
||||||
|
|
@ -35,27 +35,11 @@ jobs:
|
||||||
cache: pnpm
|
cache: pnpm
|
||||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||||
|
|
||||||
- name: Restore Jest cache
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: web/.cache/jest
|
|
||||||
key: ${{ runner.os }}-jest-${{ hashFiles('web/pnpm-lock.yaml') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-jest-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm install --frozen-lockfile
|
run: pnpm install --frozen-lockfile
|
||||||
|
|
||||||
- name: Check i18n types synchronization
|
|
||||||
run: pnpm run check:i18n-types
|
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: pnpm test:coverage
|
||||||
pnpm exec jest \
|
|
||||||
--ci \
|
|
||||||
--maxWorkers=100% \
|
|
||||||
--coverage \
|
|
||||||
--passWithNoTests
|
|
||||||
|
|
||||||
- name: Coverage Summary
|
- name: Coverage Summary
|
||||||
if: always()
|
if: always()
|
||||||
|
|
@ -69,7 +53,7 @@ jobs:
|
||||||
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
|
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
|
||||||
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
||||||
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
|
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
|
||||||
echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
||||||
exit 0
|
exit 0
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
@ -365,7 +349,7 @@ jobs:
|
||||||
.join(' | ')} |`;
|
.join(' | ')} |`;
|
||||||
|
|
||||||
console.log('');
|
console.log('');
|
||||||
console.log('<details><summary>Jest coverage table</summary>');
|
console.log('<details><summary>Vitest coverage table</summary>');
|
||||||
console.log('');
|
console.log('');
|
||||||
console.log(headerRow);
|
console.log(headerRow);
|
||||||
console.log(dividerRow);
|
console.log(dividerRow);
|
||||||
|
|
|
||||||
|
|
@ -139,7 +139,6 @@ pyrightconfig.json
|
||||||
.idea/'
|
.idea/'
|
||||||
|
|
||||||
.DS_Store
|
.DS_Store
|
||||||
web/.vscode/settings.json
|
|
||||||
|
|
||||||
# Intellij IDEA Files
|
# Intellij IDEA Files
|
||||||
.idea/*
|
.idea/*
|
||||||
|
|
@ -196,6 +195,7 @@ docker/nginx/ssl/*
|
||||||
!docker/nginx/ssl/.gitkeep
|
!docker/nginx/ssl/.gitkeep
|
||||||
docker/middleware.env
|
docker/middleware.env
|
||||||
docker/docker-compose.override.yaml
|
docker/docker-compose.override.yaml
|
||||||
|
docker/env-backup/*
|
||||||
|
|
||||||
sdks/python-client/build
|
sdks/python-client/build
|
||||||
sdks/python-client/dist
|
sdks/python-client/dist
|
||||||
|
|
@ -205,7 +205,6 @@ sdks/python-client/dify_client.egg-info
|
||||||
!.vscode/launch.json.template
|
!.vscode/launch.json.template
|
||||||
!.vscode/README.md
|
!.vscode/README.md
|
||||||
api/.vscode
|
api/.vscode
|
||||||
web/.vscode
|
|
||||||
# vscode Code History Extension
|
# vscode Code History Extension
|
||||||
.history
|
.history
|
||||||
|
|
||||||
|
|
@ -220,15 +219,6 @@ plugins.jsonl
|
||||||
# mise
|
# mise
|
||||||
mise.toml
|
mise.toml
|
||||||
|
|
||||||
# Next.js build output
|
|
||||||
.next/
|
|
||||||
|
|
||||||
# PWA generated files
|
|
||||||
web/public/sw.js
|
|
||||||
web/public/sw.js.map
|
|
||||||
web/public/workbox-*.js
|
|
||||||
web/public/workbox-*.js.map
|
|
||||||
web/public/fallback-*.js
|
|
||||||
|
|
||||||
# AI Assistant
|
# AI Assistant
|
||||||
.roo/
|
.roo/
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
|
||||||
ALIYUN_OSS_REGION=your-region
|
ALIYUN_OSS_REGION=your-region
|
||||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||||
ALIYUN_OSS_PATH=your-path
|
ALIYUN_OSS_PATH=your-path
|
||||||
|
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||||
|
|
||||||
# Google Storage configuration
|
# Google Storage configuration
|
||||||
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
|
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
|
||||||
|
|
@ -133,6 +134,7 @@ HUAWEI_OBS_BUCKET_NAME=your-bucket-name
|
||||||
HUAWEI_OBS_SECRET_KEY=your-secret-key
|
HUAWEI_OBS_SECRET_KEY=your-secret-key
|
||||||
HUAWEI_OBS_ACCESS_KEY=your-access-key
|
HUAWEI_OBS_ACCESS_KEY=your-access-key
|
||||||
HUAWEI_OBS_SERVER=your-server-url
|
HUAWEI_OBS_SERVER=your-server-url
|
||||||
|
HUAWEI_OBS_PATH_STYLE=false
|
||||||
|
|
||||||
# Baidu OBS Storage Configuration
|
# Baidu OBS Storage Configuration
|
||||||
BAIDU_OBS_BUCKET_NAME=your-bucket-name
|
BAIDU_OBS_BUCKET_NAME=your-bucket-name
|
||||||
|
|
@ -690,7 +692,6 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
|
||||||
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
|
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
|
||||||
# Maximum number of concurrent annotation import tasks per tenant
|
# Maximum number of concurrent annotation import tasks per tenant
|
||||||
ANNOTATION_IMPORT_MAX_CONCURRENT=5
|
ANNOTATION_IMPORT_MAX_CONCURRENT=5
|
||||||
|
|
||||||
# Sandbox expired records clean configuration
|
# Sandbox expired records clean configuration
|
||||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
||||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||||
|
|
|
||||||
|
|
@ -41,3 +41,8 @@ class AliyunOSSStorageConfig(BaseSettings):
|
||||||
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
|
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ALIYUN_CLOUDBOX_ID: str | None = Field(
|
||||||
|
description="Cloudbox id for aliyun cloudbox service",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -26,3 +26,8 @@ class HuaweiCloudOBSStorageConfig(BaseSettings):
|
||||||
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
|
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
HUAWEI_OBS_PATH_STYLE: bool = Field(
|
||||||
|
description="Flag to indicate whether to use path-style URLs for OBS requests",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
import os
|
||||||
|
from email.message import Message
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
from flask import Response
|
||||||
|
|
||||||
|
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
|
||||||
|
HTML_EXTENSIONS = frozenset({"html", "htm"})
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_mime_type(mime_type: str | None) -> str:
|
||||||
|
if not mime_type:
|
||||||
|
return ""
|
||||||
|
message = Message()
|
||||||
|
message["Content-Type"] = mime_type
|
||||||
|
return message.get_content_type().strip().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_html_extension(extension: str | None) -> bool:
|
||||||
|
if not extension:
|
||||||
|
return False
|
||||||
|
return extension.lstrip(".").lower() in HTML_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
def is_html_content(mime_type: str | None, filename: str | None, extension: str | None = None) -> bool:
|
||||||
|
normalized_mime_type = _normalize_mime_type(mime_type)
|
||||||
|
if normalized_mime_type in HTML_MIME_TYPES:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if _is_html_extension(extension):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if filename:
|
||||||
|
return _is_html_extension(os.path.splitext(filename)[1])
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def enforce_download_for_html(
|
||||||
|
response: Response,
|
||||||
|
*,
|
||||||
|
mime_type: str | None,
|
||||||
|
filename: str | None,
|
||||||
|
extension: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
if not is_html_content(mime_type, filename, extension):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if filename:
|
||||||
|
encoded_filename = quote(filename)
|
||||||
|
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||||
|
else:
|
||||||
|
response.headers["Content-Disposition"] = "attachment"
|
||||||
|
|
||||||
|
response.headers["Content-Type"] = "application/octet-stream"
|
||||||
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
|
return True
|
||||||
|
|
@ -7,9 +7,9 @@ from controllers.console import console_ns
|
||||||
from controllers.console.error import AlreadyActivateError
|
from controllers.console.error import AlreadyActivateError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import EmailStr, extract_remote_ip, timezone
|
from libs.helper import EmailStr, timezone
|
||||||
from models import AccountStatus
|
from models import AccountStatus
|
||||||
from services.account_service import AccountService, RegisterService
|
from services.account_service import RegisterService
|
||||||
|
|
||||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
|
|
||||||
|
|
@ -93,7 +93,6 @@ class ActivateApi(Resource):
|
||||||
"ActivationResponse",
|
"ActivationResponse",
|
||||||
{
|
{
|
||||||
"result": fields.String(description="Operation result"),
|
"result": fields.String(description="Operation result"),
|
||||||
"data": fields.Raw(description="Login token data"),
|
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
@ -117,6 +116,4 @@ class ActivateApi(Resource):
|
||||||
account.initialized_at = naive_utc_now()
|
account.initialized_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
return {"result": "success"}
|
||||||
|
|
||||||
return {"result": "success", "data": token_pair.model_dump()}
|
|
||||||
|
|
|
||||||
|
|
@ -572,7 +572,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||||
datasource_type=DatasourceType.NOTION,
|
datasource_type=DatasourceType.NOTION,
|
||||||
notion_info=NotionInfo.model_validate(
|
notion_info=NotionInfo.model_validate(
|
||||||
{
|
{
|
||||||
"credential_id": data_source_info["credential_id"],
|
"credential_id": data_source_info.get("credential_id"),
|
||||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||||
"notion_obj_id": data_source_info["notion_page_id"],
|
"notion_obj_id": data_source_info["notion_page_id"],
|
||||||
"notion_page_type": data_source_info["type"],
|
"notion_page_type": data_source_info["type"],
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import marshal_with
|
from flask_restx import marshal_with
|
||||||
|
|
@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||||
|
from libs.helper import UUIDStrOrEmpty
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models import Account
|
from models import Account
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
@ -24,7 +24,7 @@ from .. import console_ns
|
||||||
|
|
||||||
|
|
||||||
class ConversationListQuery(BaseModel):
|
class ConversationListQuery(BaseModel):
|
||||||
last_id: UUID | None = None
|
last_id: UUIDStrOrEmpty | None = None
|
||||||
limit: int = Field(default=20, ge=1, le=100)
|
limit: int = Field(default=20, ge=1, le=100)
|
||||||
pinned: bool | None = None
|
pinned: bool | None = None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,8 @@ import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, inputs, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import and_, select
|
from sqlalchemy import and_, select
|
||||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||||
|
|
||||||
|
|
@ -18,6 +19,15 @@ from services.account_service import TenantService
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
|
||||||
|
class InstalledAppCreatePayload(BaseModel):
|
||||||
|
app_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class InstalledAppUpdatePayload(BaseModel):
|
||||||
|
is_pinned: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource):
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("apps")
|
@cloud_edition_billing_resource_check("apps")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
|
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
|
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
|
||||||
if recommended_app is None:
|
if recommended_app is None:
|
||||||
raise NotFound("App not found")
|
raise NotFound("Recommended app not found")
|
||||||
|
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
app = db.session.query(App).where(App.id == payload.app_id).first()
|
||||||
|
|
||||||
if app is None:
|
if app is None:
|
||||||
raise NotFound("App not found")
|
raise NotFound("App entity not found")
|
||||||
|
|
||||||
if not app.is_public:
|
if not app.is_public:
|
||||||
raise Forbidden("You can't install a non-public app")
|
raise Forbidden("You can't install a non-public app")
|
||||||
|
|
||||||
installed_app = (
|
installed_app = (
|
||||||
db.session.query(InstalledApp)
|
db.session.query(InstalledApp)
|
||||||
.where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
|
.where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource):
|
||||||
recommended_app.install_count += 1
|
recommended_app.install_count += 1
|
||||||
|
|
||||||
new_installed_app = InstalledApp(
|
new_installed_app = InstalledApp(
|
||||||
app_id=args["app_id"],
|
app_id=payload.app_id,
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
app_owner_tenant_id=app.tenant_id,
|
app_owner_tenant_id=app.tenant_id,
|
||||||
is_pinned=False,
|
is_pinned=False,
|
||||||
|
|
@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource):
|
||||||
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
||||||
|
|
||||||
def patch(self, installed_app):
|
def patch(self, installed_app):
|
||||||
parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
|
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
commit_args = False
|
commit_args = False
|
||||||
if "is_pinned" in args:
|
if payload.is_pinned is not None:
|
||||||
installed_app.is_pinned = args["is_pinned"]
|
installed_app.is_pinned = payload.is_pinned
|
||||||
commit_args = True
|
commit_args = True
|
||||||
|
|
||||||
if commit_args:
|
if commit_args:
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,32 @@
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from controllers.console import console_ns
|
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
|
||||||
from fields.api_based_extension_fields import api_based_extension_fields
|
from fields.api_based_extension_fields import api_based_extension_fields
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.api_based_extension import APIBasedExtension
|
from models.api_based_extension import APIBasedExtension
|
||||||
from services.api_based_extension_service import APIBasedExtensionService
|
from services.api_based_extension_service import APIBasedExtensionService
|
||||||
from services.code_based_extension_service import CodeBasedExtensionService
|
from services.code_based_extension_service import CodeBasedExtensionService
|
||||||
|
|
||||||
|
from ..common.schema import register_schema_models
|
||||||
|
from . import console_ns
|
||||||
|
from .wraps import account_initialization_required, setup_required
|
||||||
|
|
||||||
|
|
||||||
|
class CodeBasedExtensionQuery(BaseModel):
|
||||||
|
module: str
|
||||||
|
|
||||||
|
|
||||||
|
class APIBasedExtensionPayload(BaseModel):
|
||||||
|
name: str = Field(description="Extension name")
|
||||||
|
api_endpoint: str = Field(description="API endpoint URL")
|
||||||
|
api_key: str = Field(description="API key for authentication")
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, APIBasedExtensionPayload)
|
||||||
|
|
||||||
|
|
||||||
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
|
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
|
||||||
|
|
||||||
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
|
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
|
||||||
|
|
@ -18,11 +36,7 @@ api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_m
|
||||||
class CodeBasedExtensionAPI(Resource):
|
class CodeBasedExtensionAPI(Resource):
|
||||||
@console_ns.doc("get_code_based_extension")
|
@console_ns.doc("get_code_based_extension")
|
||||||
@console_ns.doc(description="Get code-based extension data by module name")
|
@console_ns.doc(description="Get code-based extension data by module name")
|
||||||
@console_ns.expect(
|
@console_ns.doc(params={"module": "Extension module name"})
|
||||||
console_ns.parser().add_argument(
|
|
||||||
"module", type=str, required=True, location="args", help="Extension module name"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Success",
|
"Success",
|
||||||
|
|
@ -35,10 +49,9 @@ class CodeBasedExtensionAPI(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
|
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/api-based-extension")
|
@console_ns.route("/api-based-extension")
|
||||||
|
|
@ -56,30 +69,21 @@ class APIBasedExtensionAPI(Resource):
|
||||||
|
|
||||||
@console_ns.doc("create_api_based_extension")
|
@console_ns.doc("create_api_based_extension")
|
||||||
@console_ns.doc(description="Create a new API-based extension")
|
@console_ns.doc(description="Create a new API-based extension")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"CreateAPIBasedExtensionRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(required=True, description="Extension name"),
|
|
||||||
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
|
|
||||||
"api_key": fields.String(required=True, description="API key for authentication"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
|
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_model)
|
@marshal_with(api_based_extension_model)
|
||||||
def post(self):
|
def post(self):
|
||||||
args = console_ns.payload
|
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
extension_data = APIBasedExtension(
|
extension_data = APIBasedExtension(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
name=args["name"],
|
name=payload.name,
|
||||||
api_endpoint=args["api_endpoint"],
|
api_endpoint=payload.api_endpoint,
|
||||||
api_key=args["api_key"],
|
api_key=payload.api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
return APIBasedExtensionService.save(extension_data)
|
return APIBasedExtensionService.save(extension_data)
|
||||||
|
|
@ -104,16 +108,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||||
@console_ns.doc("update_api_based_extension")
|
@console_ns.doc("update_api_based_extension")
|
||||||
@console_ns.doc(description="Update API-based extension")
|
@console_ns.doc(description="Update API-based extension")
|
||||||
@console_ns.doc(params={"id": "Extension ID"})
|
@console_ns.doc(params={"id": "Extension ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"UpdateAPIBasedExtensionRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(required=True, description="Extension name"),
|
|
||||||
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
|
|
||||||
"api_key": fields.String(required=True, description="API key for authentication"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
|
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
@ -125,13 +120,13 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||||
|
|
||||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
args = console_ns.payload
|
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||||
|
|
||||||
extension_data_from_db.name = args["name"]
|
extension_data_from_db.name = payload.name
|
||||||
extension_data_from_db.api_endpoint = args["api_endpoint"]
|
extension_data_from_db.api_endpoint = payload.api_endpoint
|
||||||
|
|
||||||
if args["api_key"] != HIDDEN_VALUE:
|
if payload.api_key != HIDDEN_VALUE:
|
||||||
extension_data_from_db.api_key = args["api_key"]
|
extension_data_from_db.api_key = payload.api_key
|
||||||
|
|
||||||
return APIBasedExtensionService.save(extension_data_from_db)
|
return APIBasedExtensionService.save(extension_data_from_db)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import io
|
import io
|
||||||
from typing import Literal
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
from flask import request, send_file
|
from flask import request, send_file
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
|
|
@ -141,6 +142,15 @@ class ParserDynamicOptions(BaseModel):
|
||||||
provider_type: Literal["tool", "trigger"]
|
provider_type: Literal["tool", "trigger"]
|
||||||
|
|
||||||
|
|
||||||
|
class ParserDynamicOptionsWithCredentials(BaseModel):
|
||||||
|
plugin_id: str
|
||||||
|
provider: str
|
||||||
|
action: str
|
||||||
|
parameter: str
|
||||||
|
credential_id: str
|
||||||
|
credentials: Mapping[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class PluginPermissionSettingsPayload(BaseModel):
|
class PluginPermissionSettingsPayload(BaseModel):
|
||||||
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
|
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
|
||||||
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
|
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
|
||||||
|
|
@ -183,6 +193,7 @@ reg(ParserGithubUpgrade)
|
||||||
reg(ParserUninstall)
|
reg(ParserUninstall)
|
||||||
reg(ParserPermissionChange)
|
reg(ParserPermissionChange)
|
||||||
reg(ParserDynamicOptions)
|
reg(ParserDynamicOptions)
|
||||||
|
reg(ParserDynamicOptionsWithCredentials)
|
||||||
reg(ParserPreferencesChange)
|
reg(ParserPreferencesChange)
|
||||||
reg(ParserExcludePlugin)
|
reg(ParserExcludePlugin)
|
||||||
reg(ParserReadme)
|
reg(ParserReadme)
|
||||||
|
|
@ -657,6 +668,37 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||||
return jsonable_encoder({"options": options})
|
return jsonable_encoder({"options": options})
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options-with-credentials")
|
||||||
|
class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
|
||||||
|
@console_ns.expect(console_ns.models[ParserDynamicOptionsWithCredentials.__name__])
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@is_admin_or_owner_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self):
|
||||||
|
"""Fetch dynamic options using credentials directly (for edit mode)."""
|
||||||
|
current_user, tenant_id = current_account_with_tenant()
|
||||||
|
user_id = current_user.id
|
||||||
|
|
||||||
|
args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload)
|
||||||
|
|
||||||
|
try:
|
||||||
|
options = PluginParameterService.get_dynamic_select_options_with_credentials(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
plugin_id=args.plugin_id,
|
||||||
|
provider=args.provider,
|
||||||
|
action=args.action,
|
||||||
|
parameter=args.parameter,
|
||||||
|
credential_id=args.credential_id,
|
||||||
|
credentials=args.credentials,
|
||||||
|
)
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
return jsonable_encoder({"options": options})
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/plugin/preferences/change")
|
@console_ns.route("/workspaces/current/plugin/preferences/change")
|
||||||
class PluginChangePreferencesApi(Resource):
|
class PluginChangePreferencesApi(Resource):
|
||||||
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
|
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from controllers.console.wraps import (
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||||
|
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||||
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
|
||||||
from core.mcp.mcp_client import MCPClient
|
from core.mcp.mcp_client import MCPClient
|
||||||
|
|
@ -944,7 +945,7 @@ class ToolProviderMCPApi(Resource):
|
||||||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||||
|
|
||||||
# Create provider
|
# Create provider in transaction
|
||||||
with Session(db.engine) as session, session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
result = service.create_provider(
|
result = service.create_provider(
|
||||||
|
|
@ -960,7 +961,11 @@ class ToolProviderMCPApi(Resource):
|
||||||
configuration=configuration,
|
configuration=configuration,
|
||||||
authentication=authentication,
|
authentication=authentication,
|
||||||
)
|
)
|
||||||
return jsonable_encoder(result)
|
|
||||||
|
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||||
|
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||||
|
|
||||||
|
return jsonable_encoder(result)
|
||||||
|
|
||||||
@console_ns.expect(parser_mcp_put)
|
@console_ns.expect(parser_mcp_put)
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -972,17 +977,23 @@ class ToolProviderMCPApi(Resource):
|
||||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
|
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
|
||||||
validation_result = None
|
validation_data = None
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
validation_result = service.validate_server_url_change(
|
validation_data = service.get_provider_for_url_validation(
|
||||||
tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
|
tenant_id=current_tenant_id, provider_id=args["provider_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# No need to check for errors here, exceptions will be raised directly
|
# Step 2: Perform URL validation with network I/O OUTSIDE of any database session
|
||||||
|
# This prevents holding database locks during potentially slow network operations
|
||||||
|
validation_result = MCPToolManageService.validate_server_url_standalone(
|
||||||
|
tenant_id=current_tenant_id,
|
||||||
|
new_server_url=args["server_url"],
|
||||||
|
validation_data=validation_data,
|
||||||
|
)
|
||||||
|
|
||||||
# Step 2: Perform database update in a transaction
|
# Step 3: Perform database update in a transaction
|
||||||
with Session(db.engine) as session, session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
service.update_provider(
|
service.update_provider(
|
||||||
|
|
@ -999,7 +1010,11 @@ class ToolProviderMCPApi(Resource):
|
||||||
authentication=authentication,
|
authentication=authentication,
|
||||||
validation_result=validation_result,
|
validation_result=validation_result,
|
||||||
)
|
)
|
||||||
return {"result": "success"}
|
|
||||||
|
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||||
|
ToolProviderListCache.invalidate_cache(current_tenant_id)
|
||||||
|
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
@console_ns.expect(parser_mcp_delete)
|
@console_ns.expect(parser_mcp_delete)
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -1012,7 +1027,11 @@ class ToolProviderMCPApi(Resource):
|
||||||
with Session(db.engine) as session, session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
|
||||||
return {"result": "success"}
|
|
||||||
|
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||||
|
ToolProviderListCache.invalidate_cache(current_tenant_id)
|
||||||
|
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
parser_auth = (
|
parser_auth = (
|
||||||
|
|
@ -1062,6 +1081,8 @@ class ToolMCPAuthApi(Resource):
|
||||||
credentials=provider_entity.credentials,
|
credentials=provider_entity.credentials,
|
||||||
authed=True,
|
authed=True,
|
||||||
)
|
)
|
||||||
|
# Invalidate cache after updating credentials
|
||||||
|
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
except MCPAuthError as e:
|
except MCPAuthError as e:
|
||||||
try:
|
try:
|
||||||
|
|
@ -1075,16 +1096,22 @@ class ToolMCPAuthApi(Resource):
|
||||||
with Session(db.engine) as session, session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
response = service.execute_auth_actions(auth_result)
|
response = service.execute_auth_actions(auth_result)
|
||||||
|
# Invalidate cache after auth actions may have updated provider state
|
||||||
|
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||||
return response
|
return response
|
||||||
except MCPRefreshTokenError as e:
|
except MCPRefreshTokenError as e:
|
||||||
with Session(db.engine) as session, session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
# Invalidate cache after clearing credentials
|
||||||
|
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||||
except (MCPError, ValueError) as e:
|
except (MCPError, ValueError) as e:
|
||||||
with Session(db.engine) as session, session.begin():
|
with Session(db.engine) as session, session.begin():
|
||||||
service = MCPToolManageService(session=session)
|
service = MCPToolManageService(session=session)
|
||||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
|
# Invalidate cache after clearing credentials
|
||||||
|
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,15 @@
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from flask import make_response, redirect, request
|
from flask import make_response, redirect, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import BadRequest, Forbidden
|
from werkzeug.exceptions import BadRequest, Forbidden
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||||
from controllers.web.error import NotFoundError
|
from controllers.web.error import NotFoundError
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.entities.plugin_daemon import CredentialType
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
|
@ -32,6 +36,32 @@ from ..wraps import (
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionUpdateRequest(BaseModel):
|
||||||
|
"""Request payload for updating a trigger subscription"""
|
||||||
|
|
||||||
|
name: str | None = Field(default=None, description="The name for the subscription")
|
||||||
|
credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
|
||||||
|
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
|
||||||
|
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionVerifyRequest(BaseModel):
|
||||||
|
"""Request payload for verifying subscription credentials."""
|
||||||
|
|
||||||
|
credentials: Mapping[str, Any] = Field(description="The credentials to verify")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
TriggerSubscriptionUpdateRequest.__name__,
|
||||||
|
TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
|
||||||
|
)
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
TriggerSubscriptionVerifyRequest.__name__,
|
||||||
|
TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
|
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
|
||||||
class TriggerProviderIconApi(Resource):
|
class TriggerProviderIconApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
|
|
@ -155,16 +185,16 @@ parser_api = (
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route(
|
@console_ns.route(
|
||||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
|
||||||
)
|
)
|
||||||
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
|
||||||
@console_ns.expect(parser_api)
|
@console_ns.expect(parser_api)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@edit_permission_required
|
@edit_permission_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, provider, subscription_builder_id):
|
def post(self, provider, subscription_builder_id):
|
||||||
"""Verify a subscription instance for a trigger provider"""
|
"""Verify and update a subscription instance for a trigger provider"""
|
||||||
user = current_user
|
user = current_user
|
||||||
assert user.current_tenant_id is not None
|
assert user.current_tenant_id is not None
|
||||||
|
|
||||||
|
|
@ -289,6 +319,83 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||||
raise ValueError(str(e)) from e
|
raise ValueError(str(e)) from e
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/update",
|
||||||
|
)
|
||||||
|
class TriggerSubscriptionUpdateApi(Resource):
|
||||||
|
@console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__])
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@edit_permission_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, subscription_id: str):
|
||||||
|
"""Update a subscription instance"""
|
||||||
|
user = current_user
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
|
||||||
|
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
|
||||||
|
|
||||||
|
subscription = TriggerProviderService.get_subscription_by_id(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
)
|
||||||
|
if not subscription:
|
||||||
|
raise NotFoundError(f"Subscription {subscription_id} not found")
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID(subscription.provider_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# rename only
|
||||||
|
if (
|
||||||
|
args.name is not None
|
||||||
|
and args.credentials is None
|
||||||
|
and args.parameters is None
|
||||||
|
and args.properties is None
|
||||||
|
):
|
||||||
|
TriggerProviderService.update_trigger_subscription(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
name=args.name,
|
||||||
|
)
|
||||||
|
return 200
|
||||||
|
|
||||||
|
# rebuild for create automatically by the provider
|
||||||
|
match subscription.credential_type:
|
||||||
|
case CredentialType.UNAUTHORIZED:
|
||||||
|
TriggerProviderService.update_trigger_subscription(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
name=args.name,
|
||||||
|
properties=args.properties,
|
||||||
|
)
|
||||||
|
return 200
|
||||||
|
case CredentialType.API_KEY | CredentialType.OAUTH2:
|
||||||
|
if args.credentials:
|
||||||
|
new_credentials: dict[str, Any] = {
|
||||||
|
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
|
||||||
|
for key, value in args.credentials.items()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
new_credentials = subscription.credentials
|
||||||
|
|
||||||
|
TriggerProviderService.rebuild_trigger_subscription(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
name=args.name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
credentials=new_credentials,
|
||||||
|
parameters=args.parameters or subscription.parameters,
|
||||||
|
)
|
||||||
|
return 200
|
||||||
|
case _:
|
||||||
|
raise BadRequest("Invalid credential type")
|
||||||
|
except ValueError as e:
|
||||||
|
raise BadRequest(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error updating subscription", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route(
|
@console_ns.route(
|
||||||
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
||||||
)
|
)
|
||||||
|
|
@ -576,3 +683,38 @@ class TriggerOAuthClientManageApi(Resource):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error removing OAuth client", exc_info=e)
|
logger.exception("Error removing OAuth client", exc_info=e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/verify/<path:subscription_id>",
|
||||||
|
)
|
||||||
|
class TriggerSubscriptionVerifyApi(Resource):
|
||||||
|
@console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__])
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@edit_permission_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider, subscription_id):
|
||||||
|
"""Verify credentials for an existing subscription (edit mode only)"""
|
||||||
|
user = current_user
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
|
||||||
|
verify_request: TriggerSubscriptionVerifyRequest = TriggerSubscriptionVerifyRequest.model_validate(
|
||||||
|
console_ns.payload
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = TriggerProviderService.verify_subscription_credentials(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
provider_id=TriggerProviderID(provider),
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
credentials=verify_request.credentials,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning("Credential verification failed", exc_info=e)
|
||||||
|
raise BadRequest(str(e)) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error verifying subscription credentials", exc_info=e)
|
||||||
|
raise BadRequest(str(e)) from e
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.common.errors import UnsupportedFileTypeError
|
from controllers.common.errors import UnsupportedFileTypeError
|
||||||
|
from controllers.common.file_response import enforce_download_for_html
|
||||||
from controllers.files import files_ns
|
from controllers.files import files_ns
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
|
|
@ -138,6 +139,13 @@ class FilePreviewApi(Resource):
|
||||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||||
response.headers["Content-Type"] = "application/octet-stream"
|
response.headers["Content-Type"] = "application/octet-stream"
|
||||||
|
|
||||||
|
enforce_download_for_html(
|
||||||
|
response,
|
||||||
|
mime_type=upload_file.mime_type,
|
||||||
|
filename=upload_file.name,
|
||||||
|
extension=upload_file.extension,
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
from controllers.common.errors import UnsupportedFileTypeError
|
from controllers.common.errors import UnsupportedFileTypeError
|
||||||
|
from controllers.common.file_response import enforce_download_for_html
|
||||||
from controllers.files import files_ns
|
from controllers.files import files_ns
|
||||||
from core.tools.signature import verify_tool_file_signature
|
from core.tools.signature import verify_tool_file_signature
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
|
@ -78,4 +79,11 @@ class ToolFileApi(Resource):
|
||||||
encoded_filename = quote(tool_file.name)
|
encoded_filename = quote(tool_file.name)
|
||||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||||
|
|
||||||
|
enforce_download_for_html(
|
||||||
|
response,
|
||||||
|
mime_type=tool_file.mimetype,
|
||||||
|
filename=tool_file.name,
|
||||||
|
extension=extension,
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from uuid import UUID
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from flask_restx._http import HTTPStatus
|
from flask_restx._http import HTTPStatus
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
|
|
@ -51,6 +51,32 @@ class ConversationRenamePayload(BaseModel):
|
||||||
class ConversationVariablesQuery(BaseModel):
|
class ConversationVariablesQuery(BaseModel):
|
||||||
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
||||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||||
|
variable_name: str | None = Field(
|
||||||
|
default=None, description="Filter variables by name", min_length=1, max_length=255
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("variable_name", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_variable_name(cls, v: str | None) -> str | None:
|
||||||
|
"""
|
||||||
|
Validate variable_name to prevent injection attacks.
|
||||||
|
"""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
|
||||||
|
# Only allow safe characters: alphanumeric, underscore, hyphen, period
|
||||||
|
if not v.replace("-", "").replace("_", "").replace(".", "").isalnum():
|
||||||
|
raise ValueError(
|
||||||
|
"Variable name can only contain letters, numbers, hyphens (-), underscores (_), and periods (.)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prevent SQL injection patterns
|
||||||
|
dangerous_patterns = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
|
||||||
|
for pattern in dangerous_patterns:
|
||||||
|
if pattern in v.lower():
|
||||||
|
raise ValueError(f"Variable name contains invalid characters: {pattern}")
|
||||||
|
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class ConversationVariableUpdatePayload(BaseModel):
|
class ConversationVariableUpdatePayload(BaseModel):
|
||||||
|
|
@ -199,7 +225,7 @@ class ConversationVariablesApi(Resource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ConversationService.get_conversational_variable(
|
return ConversationService.get_conversational_variable(
|
||||||
app_model, conversation_id, end_user, query_args.limit, last_id
|
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
|
||||||
)
|
)
|
||||||
except services.errors.conversation.ConversationNotExistsError:
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from flask import Response, request
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from controllers.common.file_response import enforce_download_for_html
|
||||||
from controllers.common.schema import register_schema_model
|
from controllers.common.schema import register_schema_model
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.service_api.app.error import (
|
from controllers.service_api.app.error import (
|
||||||
|
|
@ -183,6 +184,13 @@ class FilePreviewApi(Resource):
|
||||||
# Override content-type for downloads to force download
|
# Override content-type for downloads to force download
|
||||||
response.headers["Content-Type"] = "application/octet-stream"
|
response.headers["Content-Type"] = "application/octet-stream"
|
||||||
|
|
||||||
|
enforce_download_for_html(
|
||||||
|
response,
|
||||||
|
mime_type=upload_file.mime_type,
|
||||||
|
filename=upload_file.name,
|
||||||
|
extension=upload_file.extension,
|
||||||
|
)
|
||||||
|
|
||||||
# Add caching headers for performance
|
# Add caching headers for performance
|
||||||
response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour
|
response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,13 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from werkzeug.exceptions import Unauthorized
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
from constants import HEADER_NAME_APP_CODE
|
from constants import HEADER_NAME_APP_CODE
|
||||||
from controllers.common import fields
|
from controllers.common import fields
|
||||||
from controllers.web import web_ns
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.web.error import AppUnavailableError
|
|
||||||
from controllers.web.wraps import WebApiResource
|
|
||||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||||
from libs.passport import PassportService
|
from libs.passport import PassportService
|
||||||
from libs.token import extract_webapp_passport
|
from libs.token import extract_webapp_passport
|
||||||
|
|
@ -18,9 +17,23 @@ from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
from services.webapp_auth_service import WebAppAuthService
|
from services.webapp_auth_service import WebAppAuthService
|
||||||
|
|
||||||
|
from . import web_ns
|
||||||
|
from .error import AppUnavailableError
|
||||||
|
from .wraps import WebApiResource
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AppAccessModeQuery(BaseModel):
|
||||||
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
app_id: str | None = Field(default=None, alias="appId", description="Application ID")
|
||||||
|
app_code: str | None = Field(default=None, alias="appCode", description="Application code")
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(web_ns, AppAccessModeQuery)
|
||||||
|
|
||||||
|
|
||||||
@web_ns.route("/parameters")
|
@web_ns.route("/parameters")
|
||||||
class AppParameterApi(WebApiResource):
|
class AppParameterApi(WebApiResource):
|
||||||
"""Resource for app variables."""
|
"""Resource for app variables."""
|
||||||
|
|
@ -96,21 +109,16 @@ class AppAccessMode(Resource):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = (
|
raw_args = request.args.to_dict()
|
||||||
reqparse.RequestParser()
|
args = AppAccessModeQuery.model_validate(raw_args)
|
||||||
.add_argument("appId", type=str, required=False, location="args")
|
|
||||||
.add_argument("appCode", type=str, required=False, location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
features = FeatureService.get_system_features()
|
features = FeatureService.get_system_features()
|
||||||
if not features.webapp_auth.enabled:
|
if not features.webapp_auth.enabled:
|
||||||
return {"accessMode": "public"}
|
return {"accessMode": "public"}
|
||||||
|
|
||||||
app_id = args.get("appId")
|
app_id = args.app_id
|
||||||
if args.get("appCode"):
|
if args.app_code:
|
||||||
app_code = args["appCode"]
|
app_id = AppService.get_app_id_by_code(args.app_code)
|
||||||
app_id = AppService.get_app_id_by_code(app_code)
|
|
||||||
|
|
||||||
if not app_id:
|
if not app_id:
|
||||||
raise ValueError("appId or appCode must be provided")
|
raise ValueError("appId or appCode must be provided")
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,12 @@ import base64
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console.auth.error import (
|
from controllers.console.auth.error import (
|
||||||
AuthenticationFailedError,
|
AuthenticationFailedError,
|
||||||
EmailCodeError,
|
EmailCodeError,
|
||||||
|
|
@ -18,14 +20,40 @@ from controllers.console.error import EmailSendIpLimitError
|
||||||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import EmailStr, extract_remote_ip
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password, valid_password
|
||||||
from models import Account
|
from models import Account
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordSendPayload(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
language: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordCheckPayload(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
code: str
|
||||||
|
token: str = Field(min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordResetPayload(BaseModel):
|
||||||
|
token: str = Field(min_length=1)
|
||||||
|
new_password: str
|
||||||
|
password_confirm: str
|
||||||
|
|
||||||
|
@field_validator("new_password", "password_confirm")
|
||||||
|
@classmethod
|
||||||
|
def validate_password(cls, value: str) -> str:
|
||||||
|
return valid_password(value)
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
|
||||||
|
|
||||||
|
|
||||||
@web_ns.route("/forgot-password")
|
@web_ns.route("/forgot-password")
|
||||||
class ForgotPasswordSendEmailApi(Resource):
|
class ForgotPasswordSendEmailApi(Resource):
|
||||||
|
@web_ns.expect(web_ns.models[ForgotPasswordSendPayload.__name__])
|
||||||
@only_edition_enterprise
|
@only_edition_enterprise
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
|
|
@ -40,35 +68,31 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=email, required=True, location="json")
|
|
||||||
.add_argument("language", type=str, required=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
if AccountService.is_email_send_ip_limit(ip_address):
|
if AccountService.is_email_send_ip_limit(ip_address):
|
||||||
raise EmailSendIpLimitError()
|
raise EmailSendIpLimitError()
|
||||||
|
|
||||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
if payload.language == "zh-Hans":
|
||||||
language = "zh-Hans"
|
language = "zh-Hans"
|
||||||
else:
|
else:
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
|
||||||
token = None
|
token = None
|
||||||
if account is None:
|
if account is None:
|
||||||
raise AuthenticationFailedError()
|
raise AuthenticationFailedError()
|
||||||
else:
|
else:
|
||||||
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
|
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
|
||||||
|
|
||||||
return {"result": "success", "data": token}
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
|
|
||||||
@web_ns.route("/forgot-password/validity")
|
@web_ns.route("/forgot-password/validity")
|
||||||
class ForgotPasswordCheckApi(Resource):
|
class ForgotPasswordCheckApi(Resource):
|
||||||
|
@web_ns.expect(web_ns.models[ForgotPasswordCheckPayload.__name__])
|
||||||
@only_edition_enterprise
|
@only_edition_enterprise
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
|
|
@ -78,45 +102,40 @@ class ForgotPasswordCheckApi(Resource):
|
||||||
responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
|
responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
|
||||||
)
|
)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("email", type=str, required=True, location="json")
|
|
||||||
.add_argument("code", type=str, required=True, location="json")
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = payload.email
|
||||||
|
|
||||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
|
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
|
||||||
if is_forgot_password_error_rate_limit:
|
if is_forgot_password_error_rate_limit:
|
||||||
raise EmailPasswordResetLimitError()
|
raise EmailPasswordResetLimitError()
|
||||||
|
|
||||||
token_data = AccountService.get_reset_password_data(args["token"])
|
token_data = AccountService.get_reset_password_data(payload.token)
|
||||||
if token_data is None:
|
if token_data is None:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
if user_email != token_data.get("email"):
|
if user_email != token_data.get("email"):
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
|
|
||||||
if args["code"] != token_data.get("code"):
|
if payload.code != token_data.get("code"):
|
||||||
AccountService.add_forgot_password_error_rate_limit(args["email"])
|
AccountService.add_forgot_password_error_rate_limit(payload.email)
|
||||||
raise EmailCodeError()
|
raise EmailCodeError()
|
||||||
|
|
||||||
# Verified, revoke the first token
|
# Verified, revoke the first token
|
||||||
AccountService.revoke_reset_password_token(args["token"])
|
AccountService.revoke_reset_password_token(payload.token)
|
||||||
|
|
||||||
# Refresh token data by generating a new token
|
# Refresh token data by generating a new token
|
||||||
_, new_token = AccountService.generate_reset_password_token(
|
_, new_token = AccountService.generate_reset_password_token(
|
||||||
user_email, code=args["code"], additional_data={"phase": "reset"}
|
user_email, code=payload.code, additional_data={"phase": "reset"}
|
||||||
)
|
)
|
||||||
|
|
||||||
AccountService.reset_forgot_password_error_rate_limit(args["email"])
|
AccountService.reset_forgot_password_error_rate_limit(payload.email)
|
||||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||||
|
|
||||||
|
|
||||||
@web_ns.route("/forgot-password/resets")
|
@web_ns.route("/forgot-password/resets")
|
||||||
class ForgotPasswordResetApi(Resource):
|
class ForgotPasswordResetApi(Resource):
|
||||||
|
@web_ns.expect(web_ns.models[ForgotPasswordResetPayload.__name__])
|
||||||
@only_edition_enterprise
|
@only_edition_enterprise
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
|
|
@ -131,20 +150,14 @@ class ForgotPasswordResetApi(Resource):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = (
|
payload = ForgotPasswordResetPayload.model_validate(web_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
|
||||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Validate passwords match
|
# Validate passwords match
|
||||||
if args["new_password"] != args["password_confirm"]:
|
if payload.new_password != payload.password_confirm:
|
||||||
raise PasswordMismatchError()
|
raise PasswordMismatchError()
|
||||||
|
|
||||||
# Validate token and get reset data
|
# Validate token and get reset data
|
||||||
reset_data = AccountService.get_reset_password_data(args["token"])
|
reset_data = AccountService.get_reset_password_data(payload.token)
|
||||||
if not reset_data:
|
if not reset_data:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
# Must use token in reset phase
|
# Must use token in reset phase
|
||||||
|
|
@ -152,11 +165,11 @@ class ForgotPasswordResetApi(Resource):
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
# Revoke token to prevent reuse
|
# Revoke token to prevent reuse
|
||||||
AccountService.revoke_reset_password_token(args["token"])
|
AccountService.revoke_reset_password_token(payload.token)
|
||||||
|
|
||||||
# Generate secure salt and hash password
|
# Generate secure salt and hash password
|
||||||
salt = secrets.token_bytes(16)
|
salt = secrets.token_bytes(16)
|
||||||
password_hashed = hash_password(args["new_password"], salt)
|
password_hashed = hash_password(payload.new_password, salt)
|
||||||
|
|
||||||
email = reset_data.get("email", "")
|
email = reset_data.get("email", "")
|
||||||
|
|
||||||
|
|
@ -170,7 +183,7 @@ class ForgotPasswordResetApi(Resource):
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
def _update_existing_account(self, account, password_hashed, salt, session):
|
def _update_existing_account(self, account: Account, password_hashed, salt, session):
|
||||||
# Update existing account credentials
|
# Update existing account credentials
|
||||||
account.password = base64.b64encode(password_hashed).decode()
|
account.password = base64.b64encode(password_hashed).decode()
|
||||||
account.password_salt = base64.b64encode(salt).decode()
|
account.password_salt = base64.b64encode(salt).decode()
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from flask_restx import fields, marshal_with, reqparse
|
from flask import request
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx import fields, marshal_with
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.web import web_ns
|
from controllers.web import web_ns
|
||||||
from controllers.web.error import (
|
from controllers.web.error import (
|
||||||
AppMoreLikeThisDisabledError,
|
AppMoreLikeThisDisabledError,
|
||||||
|
|
@ -38,6 +41,33 @@ from services.message_service import MessageService
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageListQuery(BaseModel):
|
||||||
|
conversation_id: str = Field(description="Conversation UUID")
|
||||||
|
first_id: str | None = Field(default=None, description="First message ID for pagination")
|
||||||
|
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||||
|
|
||||||
|
@field_validator("conversation_id", "first_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_uuid(cls, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return uuid_value(value)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFeedbackPayload(BaseModel):
|
||||||
|
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||||
|
content: str | None = Field(default=None, description="Feedback content")
|
||||||
|
|
||||||
|
|
||||||
|
class MessageMoreLikeThisQuery(BaseModel):
|
||||||
|
response_mode: Literal["blocking", "streaming"] = Field(
|
||||||
|
description="Response mode",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, MessageMoreLikeThisQuery)
|
||||||
|
|
||||||
|
|
||||||
@web_ns.route("/messages")
|
@web_ns.route("/messages")
|
||||||
class MessageListApi(WebApiResource):
|
class MessageListApi(WebApiResource):
|
||||||
message_fields = {
|
message_fields = {
|
||||||
|
|
@ -68,7 +98,11 @@ class MessageListApi(WebApiResource):
|
||||||
@web_ns.doc(
|
@web_ns.doc(
|
||||||
params={
|
params={
|
||||||
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": True},
|
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": True},
|
||||||
"first_id": {"description": "First message ID for pagination", "type": "string", "required": False},
|
"first_id": {
|
||||||
|
"description": "First message ID for pagination",
|
||||||
|
"type": "string",
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
"limit": {
|
"limit": {
|
||||||
"description": "Number of messages to return (1-100)",
|
"description": "Number of messages to return (1-100)",
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
|
|
@ -93,17 +127,12 @@ class MessageListApi(WebApiResource):
|
||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = (
|
raw_args = request.args.to_dict()
|
||||||
reqparse.RequestParser()
|
query = MessageListQuery.model_validate(raw_args)
|
||||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
|
||||||
.add_argument("first_id", type=uuid_value, location="args")
|
|
||||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return MessageService.pagination_by_first_id(
|
return MessageService.pagination_by_first_id(
|
||||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
app_model, end_user, query.conversation_id, query.first_id, query.limit
|
||||||
)
|
)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
@ -128,7 +157,7 @@ class MessageFeedbackApi(WebApiResource):
|
||||||
"enum": ["like", "dislike"],
|
"enum": ["like", "dislike"],
|
||||||
"required": False,
|
"required": False,
|
||||||
},
|
},
|
||||||
"content": {"description": "Feedback content/comment", "type": "string", "required": False},
|
"content": {"description": "Feedback content", "type": "string", "required": False},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@web_ns.doc(
|
@web_ns.doc(
|
||||||
|
|
@ -145,20 +174,15 @@ class MessageFeedbackApi(WebApiResource):
|
||||||
def post(self, app_model, end_user, message_id):
|
def post(self, app_model, end_user, message_id):
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = (
|
payload = MessageFeedbackPayload.model_validate(web_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
|
||||||
.add_argument("content", type=str, location="json", default=None)
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
MessageService.create_feedback(
|
MessageService.create_feedback(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
user=end_user,
|
user=end_user,
|
||||||
rating=args.get("rating"),
|
rating=payload.rating,
|
||||||
content=args.get("content"),
|
content=payload.content,
|
||||||
)
|
)
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
@ -170,17 +194,7 @@ class MessageFeedbackApi(WebApiResource):
|
||||||
class MessageMoreLikeThisApi(WebApiResource):
|
class MessageMoreLikeThisApi(WebApiResource):
|
||||||
@web_ns.doc("Generate More Like This")
|
@web_ns.doc("Generate More Like This")
|
||||||
@web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).")
|
@web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).")
|
||||||
@web_ns.doc(
|
@web_ns.expect(web_ns.models[MessageMoreLikeThisQuery.__name__])
|
||||||
params={
|
|
||||||
"message_id": {"description": "Message UUID", "type": "string", "required": True},
|
|
||||||
"response_mode": {
|
|
||||||
"description": "Response mode",
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["blocking", "streaming"],
|
|
||||||
"required": True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
@web_ns.doc(
|
@web_ns.doc(
|
||||||
responses={
|
responses={
|
||||||
200: "Success",
|
200: "Success",
|
||||||
|
|
@ -197,12 +211,10 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser().add_argument(
|
raw_args = request.args.to_dict()
|
||||||
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
|
query = MessageMoreLikeThisQuery.model_validate(raw_args)
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
streaming = args["response_mode"] == "streaming"
|
streaming = query.response_mode == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate_more_like_this(
|
response = AppGenerateService.generate_more_like_this(
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from flask_restx import marshal_with, reqparse
|
from flask_restx import marshal_with
|
||||||
|
from pydantic import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.common import helpers
|
from controllers.common import helpers
|
||||||
|
|
@ -10,14 +11,23 @@ from controllers.common.errors import (
|
||||||
RemoteFileUploadError,
|
RemoteFileUploadError,
|
||||||
UnsupportedFileTypeError,
|
UnsupportedFileTypeError,
|
||||||
)
|
)
|
||||||
from controllers.web import web_ns
|
|
||||||
from controllers.web.wraps import WebApiResource
|
|
||||||
from core.file import helpers as file_helpers
|
from core.file import helpers as file_helpers
|
||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
|
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
from ..common.schema import register_schema_models
|
||||||
|
from . import web_ns
|
||||||
|
from .wraps import WebApiResource
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteFileUploadPayload(BaseModel):
|
||||||
|
url: HttpUrl = Field(description="Remote file URL")
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(web_ns, RemoteFileUploadPayload)
|
||||||
|
|
||||||
|
|
||||||
@web_ns.route("/remote-files/<path:url>")
|
@web_ns.route("/remote-files/<path:url>")
|
||||||
class RemoteFileInfoApi(WebApiResource):
|
class RemoteFileInfoApi(WebApiResource):
|
||||||
|
|
@ -97,10 +107,8 @@ class RemoteFileUploadApi(WebApiResource):
|
||||||
FileTooLargeError: File exceeds size limit
|
FileTooLargeError: File exceeds size limit
|
||||||
UnsupportedFileTypeError: File type not supported
|
UnsupportedFileTypeError: File type not supported
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
|
payload = RemoteFileUploadPayload.model_validate(web_ns.payload or {})
|
||||||
args = parser.parse_args()
|
url = str(payload.url)
|
||||||
|
|
||||||
url = args["url"]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = ssrf_proxy.head(url=url)
|
resp = ssrf_proxy.head(url=url)
|
||||||
|
|
|
||||||
|
|
@ -105,8 +105,9 @@ class BaseAppGenerator:
|
||||||
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
|
variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST}
|
||||||
and not variable_entity.required
|
and not variable_entity.required
|
||||||
):
|
):
|
||||||
# Treat empty string (frontend default) or empty list as unset
|
# Treat empty string (frontend default) as unset
|
||||||
if not value and isinstance(value, (str, list)):
|
# For FILE_LIST, allow empty list [] to pass through
|
||||||
|
if isinstance(value, str) and not value:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if variable_entity.type in {
|
if variable_entity.type in {
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,22 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_provided_host_header(headers: dict | None) -> str | None:
|
||||||
|
"""
|
||||||
|
Extract the user-provided Host header from the headers dict.
|
||||||
|
|
||||||
|
This is needed because when using a forward proxy, httpx may override the Host header.
|
||||||
|
We preserve the user's explicit Host header to support virtual hosting and other use cases.
|
||||||
|
"""
|
||||||
|
if not headers:
|
||||||
|
return None
|
||||||
|
# Case-insensitive lookup for Host header
|
||||||
|
for key, value in headers.items():
|
||||||
|
if key.lower() == "host":
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||||
if "allow_redirects" in kwargs:
|
if "allow_redirects" in kwargs:
|
||||||
allow_redirects = kwargs.pop("allow_redirects")
|
allow_redirects = kwargs.pop("allow_redirects")
|
||||||
|
|
@ -90,10 +106,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||||
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
|
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
|
||||||
client = _get_ssrf_client(verify_option)
|
client = _get_ssrf_client(verify_option)
|
||||||
|
|
||||||
|
# Preserve user-provided Host header
|
||||||
|
# When using a forward proxy, httpx may override the Host header based on the URL.
|
||||||
|
# We extract and preserve any explicitly set Host header to support virtual hosting.
|
||||||
|
headers = kwargs.get("headers", {})
|
||||||
|
user_provided_host = _get_user_provided_host_header(headers)
|
||||||
|
|
||||||
retries = 0
|
retries = 0
|
||||||
while retries <= max_retries:
|
while retries <= max_retries:
|
||||||
try:
|
try:
|
||||||
|
# Build the request manually to preserve the Host header
|
||||||
|
# httpx may override the Host header when using a proxy, so we use
|
||||||
|
# the request API to explicitly set headers before sending
|
||||||
|
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
|
||||||
|
if user_provided_host is not None:
|
||||||
|
headers["host"] = user_provided_host
|
||||||
|
kwargs["headers"] = headers
|
||||||
response = client.request(method=method, url=url, **kwargs)
|
response = client.request(method=method, url=url, **kwargs)
|
||||||
|
|
||||||
# Check for SSRF protection by Squid proxy
|
# Check for SSRF protection by Squid proxy
|
||||||
if response.status_code in (401, 403):
|
if response.status_code in (401, 403):
|
||||||
# Check if this is a Squid SSRF rejection
|
# Check if this is a Squid SSRF rejection
|
||||||
|
|
|
||||||
|
|
@ -396,7 +396,7 @@ class IndexingRunner:
|
||||||
datasource_type=DatasourceType.NOTION,
|
datasource_type=DatasourceType.NOTION,
|
||||||
notion_info=NotionInfo.model_validate(
|
notion_info=NotionInfo.model_validate(
|
||||||
{
|
{
|
||||||
"credential_id": data_source_info["credential_id"],
|
"credential_id": data_source_info.get("credential_id"),
|
||||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||||
"notion_obj_id": data_source_info["notion_page_id"],
|
"notion_obj_id": data_source_info["notion_page_id"],
|
||||||
"notion_page_type": data_source_info["type"],
|
"notion_page_type": data_source_info["type"],
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
|
||||||
"""
|
"""
|
||||||
Build a list of URLs to try for Protected Resource Metadata discovery.
|
Build a list of URLs to try for Protected Resource Metadata discovery.
|
||||||
|
|
||||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
|
||||||
|
Priority order:
|
||||||
|
1. URL from WWW-Authenticate header (if provided)
|
||||||
|
2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
|
||||||
|
3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
|
||||||
"""
|
"""
|
||||||
urls = []
|
urls = []
|
||||||
|
|
||||||
|
|
@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
|
||||||
# Fallback: construct from server URL
|
# Fallback: construct from server URL
|
||||||
parsed = urlparse(server_url)
|
parsed = urlparse(server_url)
|
||||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||||
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
path = parsed.path.rstrip("/")
|
||||||
if fallback_url not in urls:
|
|
||||||
urls.append(fallback_url)
|
# Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
|
||||||
|
if path:
|
||||||
|
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
|
||||||
|
if path_url not in urls:
|
||||||
|
urls.append(path_url)
|
||||||
|
|
||||||
|
# Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
|
||||||
|
root_url = f"{base_url}/.well-known/oauth-protected-resource"
|
||||||
|
if root_url not in urls:
|
||||||
|
urls.append(root_url)
|
||||||
|
|
||||||
return urls
|
return urls
|
||||||
|
|
||||||
|
|
@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st
|
||||||
|
|
||||||
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
|
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
|
||||||
|
|
||||||
Per RFC 8414 section 3:
|
Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
|
||||||
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
|
- OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
|
||||||
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
|
- OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
|
||||||
|
- OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
|
||||||
Example:
|
- OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
|
||||||
- issuer: https://example.com/oauth
|
- OpenID Connect at root: https://example.com/.well-known/openid-configuration
|
||||||
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
|
|
||||||
"""
|
"""
|
||||||
urls = []
|
urls = []
|
||||||
base_url = auth_server_url or server_url
|
base_url = auth_server_url or server_url
|
||||||
|
|
||||||
parsed = urlparse(base_url)
|
parsed = urlparse(base_url)
|
||||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||||
path = parsed.path.rstrip("/") # Remove trailing slash
|
path = parsed.path.rstrip("/")
|
||||||
|
# OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
|
||||||
|
urls.append(f"{base}/.well-known/oauth-authorization-server")
|
||||||
|
|
||||||
# Try OpenID Connect discovery first (more common)
|
# OpenID Connect Discovery at root
|
||||||
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
|
urls.append(f"{base}/.well-known/openid-configuration")
|
||||||
|
|
||||||
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
|
|
||||||
# Include the path component if present in the issuer URL
|
|
||||||
if path:
|
if path:
|
||||||
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
|
# OpenID Connect Discovery with path insertion
|
||||||
else:
|
urls.append(f"{base}/.well-known/openid-configuration{path}")
|
||||||
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
|
|
||||||
|
# OpenID Connect Discovery path appending
|
||||||
|
urls.append(f"{base}{path}/.well-known/openid-configuration")
|
||||||
|
|
||||||
|
# OAuth 2.0 Authorization Server Metadata with path insertion
|
||||||
|
urls.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
||||||
|
|
||||||
return urls
|
return urls
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,7 @@ class SSETransport:
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.sse_read_timeout = sse_read_timeout
|
self.sse_read_timeout = sse_read_timeout
|
||||||
self.endpoint_url: str | None = None
|
self.endpoint_url: str | None = None
|
||||||
|
self.event_source: EventSource | None = None
|
||||||
|
|
||||||
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
|
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
|
||||||
"""Validate that the endpoint URL matches the connection origin.
|
"""Validate that the endpoint URL matches the connection origin.
|
||||||
|
|
@ -237,6 +238,9 @@ class SSETransport:
|
||||||
write_queue: WriteQueue = queue.Queue()
|
write_queue: WriteQueue = queue.Queue()
|
||||||
status_queue: StatusQueue = queue.Queue()
|
status_queue: StatusQueue = queue.Queue()
|
||||||
|
|
||||||
|
# Store event_source for graceful shutdown
|
||||||
|
self.event_source = event_source
|
||||||
|
|
||||||
# Start SSE reader thread
|
# Start SSE reader thread
|
||||||
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
|
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
|
||||||
|
|
||||||
|
|
@ -296,6 +300,13 @@ def sse_client(
|
||||||
logger.exception("Error connecting to SSE endpoint")
|
logger.exception("Error connecting to SSE endpoint")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
# Close the SSE connection to unblock the reader thread
|
||||||
|
if transport.event_source is not None:
|
||||||
|
try:
|
||||||
|
transport.event_source.response.close()
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
# Clean up queues
|
# Clean up queues
|
||||||
if read_queue:
|
if read_queue:
|
||||||
read_queue.put(None)
|
read_queue.put(None)
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ and session management.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
|
import threading
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Generator
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
@ -103,6 +104,9 @@ class StreamableHTTPTransport:
|
||||||
CONTENT_TYPE: JSON,
|
CONTENT_TYPE: JSON,
|
||||||
**self.headers,
|
**self.headers,
|
||||||
}
|
}
|
||||||
|
self.stop_event = threading.Event()
|
||||||
|
self._active_responses: list[httpx.Response] = []
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
||||||
"""Update headers with session ID if available."""
|
"""Update headers with session ID if available."""
|
||||||
|
|
@ -111,6 +115,30 @@ class StreamableHTTPTransport:
|
||||||
headers[MCP_SESSION_ID] = self.session_id
|
headers[MCP_SESSION_ID] = self.session_id
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
def _register_response(self, response: httpx.Response):
|
||||||
|
"""Register a response for cleanup on shutdown."""
|
||||||
|
with self._lock:
|
||||||
|
self._active_responses.append(response)
|
||||||
|
|
||||||
|
def _unregister_response(self, response: httpx.Response):
|
||||||
|
"""Unregister a response after it's closed."""
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
self._active_responses.remove(response)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.debug("Ignoring error during response unregister: %s", e)
|
||||||
|
|
||||||
|
def close_active_responses(self):
|
||||||
|
"""Close all active SSE connections to unblock threads."""
|
||||||
|
with self._lock:
|
||||||
|
responses_to_close = list(self._active_responses)
|
||||||
|
self._active_responses.clear()
|
||||||
|
for response in responses_to_close:
|
||||||
|
try:
|
||||||
|
response.close()
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.debug("Ignoring error during active response close: %s", e)
|
||||||
|
|
||||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||||
"""Check if the message is an initialization request."""
|
"""Check if the message is an initialization request."""
|
||||||
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
||||||
|
|
@ -195,11 +223,21 @@ class StreamableHTTPTransport:
|
||||||
event_source.response.raise_for_status()
|
event_source.response.raise_for_status()
|
||||||
logger.debug("GET SSE connection established")
|
logger.debug("GET SSE connection established")
|
||||||
|
|
||||||
for sse in event_source.iter_sse():
|
# Register response for cleanup
|
||||||
self._handle_sse_event(sse, server_to_client_queue)
|
self._register_response(event_source.response)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for sse in event_source.iter_sse():
|
||||||
|
if self.stop_event.is_set():
|
||||||
|
logger.debug("GET stream received stop signal")
|
||||||
|
break
|
||||||
|
self._handle_sse_event(sse, server_to_client_queue)
|
||||||
|
finally:
|
||||||
|
self._unregister_response(event_source.response)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("GET stream error (non-fatal): %s", exc)
|
if not self.stop_event.is_set():
|
||||||
|
logger.debug("GET stream error (non-fatal): %s", exc)
|
||||||
|
|
||||||
def _handle_resumption_request(self, ctx: RequestContext):
|
def _handle_resumption_request(self, ctx: RequestContext):
|
||||||
"""Handle a resumption request using GET with SSE."""
|
"""Handle a resumption request using GET with SSE."""
|
||||||
|
|
@ -224,15 +262,24 @@ class StreamableHTTPTransport:
|
||||||
event_source.response.raise_for_status()
|
event_source.response.raise_for_status()
|
||||||
logger.debug("Resumption GET SSE connection established")
|
logger.debug("Resumption GET SSE connection established")
|
||||||
|
|
||||||
for sse in event_source.iter_sse():
|
# Register response for cleanup
|
||||||
is_complete = self._handle_sse_event(
|
self._register_response(event_source.response)
|
||||||
sse,
|
|
||||||
ctx.server_to_client_queue,
|
try:
|
||||||
original_request_id,
|
for sse in event_source.iter_sse():
|
||||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
if self.stop_event.is_set():
|
||||||
)
|
logger.debug("Resumption stream received stop signal")
|
||||||
if is_complete:
|
break
|
||||||
break
|
is_complete = self._handle_sse_event(
|
||||||
|
sse,
|
||||||
|
ctx.server_to_client_queue,
|
||||||
|
original_request_id,
|
||||||
|
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||||
|
)
|
||||||
|
if is_complete:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
self._unregister_response(event_source.response)
|
||||||
|
|
||||||
def _handle_post_request(self, ctx: RequestContext):
|
def _handle_post_request(self, ctx: RequestContext):
|
||||||
"""Handle a POST request with response processing."""
|
"""Handle a POST request with response processing."""
|
||||||
|
|
@ -295,17 +342,27 @@ class StreamableHTTPTransport:
|
||||||
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
|
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
|
||||||
"""Handle SSE response from the server."""
|
"""Handle SSE response from the server."""
|
||||||
try:
|
try:
|
||||||
|
# Register response for cleanup
|
||||||
|
self._register_response(response)
|
||||||
|
|
||||||
event_source = EventSource(response)
|
event_source = EventSource(response)
|
||||||
for sse in event_source.iter_sse():
|
try:
|
||||||
is_complete = self._handle_sse_event(
|
for sse in event_source.iter_sse():
|
||||||
sse,
|
if self.stop_event.is_set():
|
||||||
ctx.server_to_client_queue,
|
logger.debug("SSE response stream received stop signal")
|
||||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
break
|
||||||
)
|
is_complete = self._handle_sse_event(
|
||||||
if is_complete:
|
sse,
|
||||||
break
|
ctx.server_to_client_queue,
|
||||||
|
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||||
|
)
|
||||||
|
if is_complete:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
self._unregister_response(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ctx.server_to_client_queue.put(e)
|
if not self.stop_event.is_set():
|
||||||
|
ctx.server_to_client_queue.put(e)
|
||||||
|
|
||||||
def _handle_unexpected_content_type(
|
def _handle_unexpected_content_type(
|
||||||
self,
|
self,
|
||||||
|
|
@ -345,6 +402,11 @@ class StreamableHTTPTransport:
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
# Check if we should stop
|
||||||
|
if self.stop_event.is_set():
|
||||||
|
logger.debug("Post writer received stop signal")
|
||||||
|
break
|
||||||
|
|
||||||
# Read message from client queue with timeout to check stop_event periodically
|
# Read message from client queue with timeout to check stop_event periodically
|
||||||
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||||
if session_message is None:
|
if session_message is None:
|
||||||
|
|
@ -381,7 +443,8 @@ class StreamableHTTPTransport:
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
server_to_client_queue.put(exc)
|
if not self.stop_event.is_set():
|
||||||
|
server_to_client_queue.put(exc)
|
||||||
|
|
||||||
def terminate_session(self, client: httpx.Client):
|
def terminate_session(self, client: httpx.Client):
|
||||||
"""Terminate the session by sending a DELETE request."""
|
"""Terminate the session by sending a DELETE request."""
|
||||||
|
|
@ -465,6 +528,12 @@ def streamablehttp_client(
|
||||||
transport.get_session_id,
|
transport.get_session_id,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
# Set stop event to signal all threads to stop
|
||||||
|
transport.stop_event.set()
|
||||||
|
|
||||||
|
# Close all active SSE connections to unblock threads
|
||||||
|
transport.close_active_responses()
|
||||||
|
|
||||||
if transport.session_id and terminate_on_close:
|
if transport.session_id and terminate_on_close:
|
||||||
transport.terminate_session(client)
|
transport.terminate_session(client)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ class MCPClient:
|
||||||
try:
|
try:
|
||||||
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
|
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
|
||||||
self.connect_server(sse_client, "sse")
|
self.connect_server(sse_client, "sse")
|
||||||
except MCPConnectionError:
|
except (MCPConnectionError, ValueError):
|
||||||
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
|
||||||
self.connect_server(streamablehttp_client, "mcp")
|
self.connect_server(streamablehttp_client, "mcp")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_
|
||||||
generate dotted_order for langsmith
|
generate dotted_order for langsmith
|
||||||
"""
|
"""
|
||||||
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
|
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
|
||||||
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
|
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f") + "Z"
|
||||||
current_segment = f"{timestamp}{run_id}"
|
current_segment = f"{timestamp}{run_id}"
|
||||||
|
|
||||||
if parent_dotted_order is None:
|
if parent_dotted_order is None:
|
||||||
|
|
|
||||||
|
|
@ -90,13 +90,17 @@ class Jieba(BaseKeyword):
|
||||||
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
|
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
|
||||||
|
|
||||||
documents = []
|
documents = []
|
||||||
|
|
||||||
|
segment_query_stmt = db.session.query(DocumentSegment).where(
|
||||||
|
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
|
||||||
|
)
|
||||||
|
if document_ids_filter:
|
||||||
|
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
|
||||||
|
|
||||||
|
segments = db.session.execute(segment_query_stmt).scalars().all()
|
||||||
|
segment_map = {segment.index_node_id: segment for segment in segments}
|
||||||
for chunk_index in sorted_chunk_indices:
|
for chunk_index in sorted_chunk_indices:
|
||||||
segment_query = db.session.query(DocumentSegment).where(
|
segment = segment_map.get(chunk_index)
|
||||||
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
|
|
||||||
)
|
|
||||||
if document_ids_filter:
|
|
||||||
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
|
|
||||||
segment = segment_query.first()
|
|
||||||
|
|
||||||
if segment:
|
if segment:
|
||||||
documents.append(
|
documents.append(
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, load_only
|
from sqlalchemy.orm import Session, load_only
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.db.session_factory import session_factory
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||||
|
|
@ -138,37 +139,47 @@ class RetrievalService:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
|
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
|
||||||
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
|
"""Deduplicate documents in O(n) while preserving first-seen order.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
|
||||||
|
metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
|
||||||
|
- For non-dify documents (or dify without doc_id): deduplicate by content key
|
||||||
|
(provider, page_content), keeping the first occurrence.
|
||||||
|
"""
|
||||||
if not documents:
|
if not documents:
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
unique_documents = []
|
# Map of dedup key -> chosen Document
|
||||||
seen_doc_ids = set()
|
chosen: dict[tuple, Document] = {}
|
||||||
|
# Preserve the order of first appearance of each dedup key
|
||||||
|
order: list[tuple] = []
|
||||||
|
|
||||||
for document in documents:
|
for doc in documents:
|
||||||
# For dify provider documents, use doc_id for deduplication
|
is_dify = doc.provider == "dify"
|
||||||
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
|
doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
|
||||||
doc_id = document.metadata["doc_id"]
|
|
||||||
if doc_id not in seen_doc_ids:
|
if is_dify and doc_id:
|
||||||
seen_doc_ids.add(doc_id)
|
key = ("dify", doc_id)
|
||||||
unique_documents.append(document)
|
if key not in chosen:
|
||||||
# If duplicate, keep the one with higher score
|
chosen[key] = doc
|
||||||
elif "score" in document.metadata:
|
order.append(key)
|
||||||
# Find existing document with same doc_id and compare scores
|
else:
|
||||||
for i, existing_doc in enumerate(unique_documents):
|
# Only replace if the new one has a score and it's strictly higher
|
||||||
if (
|
if "score" in doc.metadata:
|
||||||
existing_doc.metadata
|
new_score = float(doc.metadata.get("score", 0.0))
|
||||||
and existing_doc.metadata.get("doc_id") == doc_id
|
old_score = float(chosen[key].metadata.get("score", 0.0)) if chosen[key].metadata else 0.0
|
||||||
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
|
if new_score > old_score:
|
||||||
):
|
chosen[key] = doc
|
||||||
unique_documents[i] = document
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
# For non-dify documents, use content-based deduplication
|
# Content-based dedup for non-dify or dify without doc_id
|
||||||
if document not in unique_documents:
|
content_key = (doc.provider or "dify", doc.page_content)
|
||||||
unique_documents.append(document)
|
if content_key not in chosen:
|
||||||
|
chosen[content_key] = doc
|
||||||
|
order.append(content_key)
|
||||||
|
# If duplicate content appears, we keep the first occurrence (no score comparison)
|
||||||
|
|
||||||
return unique_documents
|
return [chosen[k] for k in order]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
|
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
|
||||||
|
|
@ -371,58 +382,96 @@ class RetrievalService:
|
||||||
include_segment_ids = set()
|
include_segment_ids = set()
|
||||||
segment_child_map = {}
|
segment_child_map = {}
|
||||||
segment_file_map = {}
|
segment_file_map = {}
|
||||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
|
||||||
# Process documents
|
|
||||||
for document in documents:
|
|
||||||
segment_id = None
|
|
||||||
attachment_info = None
|
|
||||||
child_chunk = None
|
|
||||||
document_id = document.metadata.get("document_id")
|
|
||||||
if document_id not in dataset_documents:
|
|
||||||
continue
|
|
||||||
|
|
||||||
dataset_document = dataset_documents[document_id]
|
valid_dataset_documents = {}
|
||||||
if not dataset_document:
|
image_doc_ids = []
|
||||||
continue
|
child_index_node_ids = []
|
||||||
|
index_node_ids = []
|
||||||
|
doc_to_document_map = {}
|
||||||
|
for document in documents:
|
||||||
|
document_id = document.metadata.get("document_id")
|
||||||
|
if document_id not in dataset_documents:
|
||||||
|
continue
|
||||||
|
|
||||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
dataset_document = dataset_documents[document_id]
|
||||||
# Handle parent-child documents
|
if not dataset_document:
|
||||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
continue
|
||||||
attachment_info_dict = cls.get_segment_attachment_info(
|
valid_dataset_documents[document_id] = dataset_document
|
||||||
dataset_document.dataset_id,
|
|
||||||
dataset_document.tenant_id,
|
|
||||||
document.metadata.get("doc_id") or "",
|
|
||||||
session,
|
|
||||||
)
|
|
||||||
if attachment_info_dict:
|
|
||||||
attachment_info = attachment_info_dict["attachment_info"]
|
|
||||||
segment_id = attachment_info_dict["segment_id"]
|
|
||||||
else:
|
|
||||||
child_index_node_id = document.metadata.get("doc_id")
|
|
||||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
|
|
||||||
child_chunk = session.scalar(child_chunk_stmt)
|
|
||||||
|
|
||||||
if not child_chunk:
|
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
continue
|
doc_id = document.metadata.get("doc_id") or ""
|
||||||
segment_id = child_chunk.segment_id
|
doc_to_document_map[doc_id] = document
|
||||||
|
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||||
|
image_doc_ids.append(doc_id)
|
||||||
|
else:
|
||||||
|
child_index_node_ids.append(doc_id)
|
||||||
|
else:
|
||||||
|
doc_id = document.metadata.get("doc_id") or ""
|
||||||
|
doc_to_document_map[doc_id] = document
|
||||||
|
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||||
|
image_doc_ids.append(doc_id)
|
||||||
|
else:
|
||||||
|
index_node_ids.append(doc_id)
|
||||||
|
|
||||||
if not segment_id:
|
image_doc_ids = [i for i in image_doc_ids if i]
|
||||||
continue
|
child_index_node_ids = [i for i in child_index_node_ids if i]
|
||||||
|
index_node_ids = [i for i in index_node_ids if i]
|
||||||
|
|
||||||
segment = (
|
segment_ids = []
|
||||||
session.query(DocumentSegment)
|
index_node_segments: list[DocumentSegment] = []
|
||||||
.where(
|
segments: list[DocumentSegment] = []
|
||||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
attachment_map = {}
|
||||||
DocumentSegment.enabled == True,
|
child_chunk_map = {}
|
||||||
DocumentSegment.status == "completed",
|
doc_segment_map = {}
|
||||||
DocumentSegment.id == segment_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not segment:
|
with session_factory.create_session() as session:
|
||||||
continue
|
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
|
||||||
|
|
||||||
|
for attachment in attachments:
|
||||||
|
segment_ids.append(attachment["segment_id"])
|
||||||
|
attachment_map[attachment["segment_id"]] = attachment
|
||||||
|
doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
|
||||||
|
|
||||||
|
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
|
||||||
|
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
|
||||||
|
|
||||||
|
for i in child_index_nodes:
|
||||||
|
segment_ids.append(i.segment_id)
|
||||||
|
child_chunk_map[i.segment_id] = i
|
||||||
|
doc_segment_map[i.segment_id] = i.index_node_id
|
||||||
|
|
||||||
|
if index_node_ids:
|
||||||
|
document_segment_stmt = select(DocumentSegment).where(
|
||||||
|
DocumentSegment.enabled == True,
|
||||||
|
DocumentSegment.status == "completed",
|
||||||
|
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||||
|
)
|
||||||
|
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
|
||||||
|
for index_node_segment in index_node_segments:
|
||||||
|
doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
|
||||||
|
if segment_ids:
|
||||||
|
document_segment_stmt = select(DocumentSegment).where(
|
||||||
|
DocumentSegment.enabled == True,
|
||||||
|
DocumentSegment.status == "completed",
|
||||||
|
DocumentSegment.id.in_(segment_ids),
|
||||||
|
)
|
||||||
|
segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
|
||||||
|
|
||||||
|
if index_node_segments:
|
||||||
|
segments.extend(index_node_segments)
|
||||||
|
|
||||||
|
for segment in segments:
|
||||||
|
doc_id = doc_segment_map.get(segment.id)
|
||||||
|
child_chunk = child_chunk_map.get(segment.id)
|
||||||
|
attachment_info = attachment_map.get(segment.id)
|
||||||
|
|
||||||
|
if doc_id:
|
||||||
|
document = doc_to_document_map[doc_id]
|
||||||
|
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
|
||||||
|
document.metadata.get("document_id")
|
||||||
|
)
|
||||||
|
|
||||||
|
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||||
if segment.id not in include_segment_ids:
|
if segment.id not in include_segment_ids:
|
||||||
include_segment_ids.add(segment.id)
|
include_segment_ids.add(segment.id)
|
||||||
if child_chunk:
|
if child_chunk:
|
||||||
|
|
@ -430,10 +479,10 @@ class RetrievalService:
|
||||||
"id": child_chunk.id,
|
"id": child_chunk.id,
|
||||||
"content": child_chunk.content,
|
"content": child_chunk.content,
|
||||||
"position": child_chunk.position,
|
"position": child_chunk.position,
|
||||||
"score": document.metadata.get("score", 0.0),
|
"score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||||
}
|
}
|
||||||
map_detail = {
|
map_detail = {
|
||||||
"max_score": document.metadata.get("score", 0.0),
|
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||||
"child_chunks": [child_chunk_detail],
|
"child_chunks": [child_chunk_detail],
|
||||||
}
|
}
|
||||||
segment_child_map[segment.id] = map_detail
|
segment_child_map[segment.id] = map_detail
|
||||||
|
|
@ -452,13 +501,14 @@ class RetrievalService:
|
||||||
"score": document.metadata.get("score", 0.0),
|
"score": document.metadata.get("score", 0.0),
|
||||||
}
|
}
|
||||||
if segment.id in segment_child_map:
|
if segment.id in segment_child_map:
|
||||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore
|
||||||
segment_child_map[segment.id]["max_score"] = max(
|
segment_child_map[segment.id]["max_score"] = max(
|
||||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
segment_child_map[segment.id]["max_score"],
|
||||||
|
document.metadata.get("score", 0.0) if document else 0.0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
segment_child_map[segment.id] = {
|
segment_child_map[segment.id] = {
|
||||||
"max_score": document.metadata.get("score", 0.0),
|
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||||
"child_chunks": [child_chunk_detail],
|
"child_chunks": [child_chunk_detail],
|
||||||
}
|
}
|
||||||
if attachment_info:
|
if attachment_info:
|
||||||
|
|
@ -467,46 +517,11 @@ class RetrievalService:
|
||||||
else:
|
else:
|
||||||
segment_file_map[segment.id] = [attachment_info]
|
segment_file_map[segment.id] = [attachment_info]
|
||||||
else:
|
else:
|
||||||
# Handle normal documents
|
|
||||||
segment = None
|
|
||||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
|
||||||
attachment_info_dict = cls.get_segment_attachment_info(
|
|
||||||
dataset_document.dataset_id,
|
|
||||||
dataset_document.tenant_id,
|
|
||||||
document.metadata.get("doc_id") or "",
|
|
||||||
session,
|
|
||||||
)
|
|
||||||
if attachment_info_dict:
|
|
||||||
attachment_info = attachment_info_dict["attachment_info"]
|
|
||||||
segment_id = attachment_info_dict["segment_id"]
|
|
||||||
document_segment_stmt = select(DocumentSegment).where(
|
|
||||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
|
||||||
DocumentSegment.enabled == True,
|
|
||||||
DocumentSegment.status == "completed",
|
|
||||||
DocumentSegment.id == segment_id,
|
|
||||||
)
|
|
||||||
segment = session.scalar(document_segment_stmt)
|
|
||||||
if segment:
|
|
||||||
segment_file_map[segment.id] = [attachment_info]
|
|
||||||
else:
|
|
||||||
index_node_id = document.metadata.get("doc_id")
|
|
||||||
if not index_node_id:
|
|
||||||
continue
|
|
||||||
document_segment_stmt = select(DocumentSegment).where(
|
|
||||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
|
||||||
DocumentSegment.enabled == True,
|
|
||||||
DocumentSegment.status == "completed",
|
|
||||||
DocumentSegment.index_node_id == index_node_id,
|
|
||||||
)
|
|
||||||
segment = session.scalar(document_segment_stmt)
|
|
||||||
|
|
||||||
if not segment:
|
|
||||||
continue
|
|
||||||
if segment.id not in include_segment_ids:
|
if segment.id not in include_segment_ids:
|
||||||
include_segment_ids.add(segment.id)
|
include_segment_ids.add(segment.id)
|
||||||
record = {
|
record = {
|
||||||
"segment": segment,
|
"segment": segment,
|
||||||
"score": document.metadata.get("score"), # type: ignore
|
"score": document.metadata.get("score", 0.0), # type: ignore
|
||||||
}
|
}
|
||||||
if attachment_info:
|
if attachment_info:
|
||||||
segment_file_map[segment.id] = [attachment_info]
|
segment_file_map[segment.id] = [attachment_info]
|
||||||
|
|
@ -522,7 +537,7 @@ class RetrievalService:
|
||||||
for record in records:
|
for record in records:
|
||||||
if record["segment"].id in segment_child_map:
|
if record["segment"].id in segment_child_map:
|
||||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
||||||
if record["segment"].id in segment_file_map:
|
if record["segment"].id in segment_file_map:
|
||||||
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
|
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
@ -565,6 +580,8 @@ class RetrievalService:
|
||||||
flask_app: Flask,
|
flask_app: Flask,
|
||||||
retrieval_method: RetrievalMethod,
|
retrieval_method: RetrievalMethod,
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
|
all_documents: list[Document],
|
||||||
|
exceptions: list[str],
|
||||||
query: str | None = None,
|
query: str | None = None,
|
||||||
top_k: int = 4,
|
top_k: int = 4,
|
||||||
score_threshold: float | None = 0.0,
|
score_threshold: float | None = 0.0,
|
||||||
|
|
@ -573,8 +590,6 @@ class RetrievalService:
|
||||||
weights: dict | None = None,
|
weights: dict | None = None,
|
||||||
document_ids_filter: list[str] | None = None,
|
document_ids_filter: list[str] | None = None,
|
||||||
attachment_id: str | None = None,
|
attachment_id: str | None = None,
|
||||||
all_documents: list[Document] = [],
|
|
||||||
exceptions: list[str] = [],
|
|
||||||
):
|
):
|
||||||
if not query and not attachment_id:
|
if not query and not attachment_id:
|
||||||
return
|
return
|
||||||
|
|
@ -696,3 +711,37 @@ class RetrievalService:
|
||||||
}
|
}
|
||||||
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
|
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
|
||||||
|
attachment_infos = []
|
||||||
|
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||||
|
if upload_files:
|
||||||
|
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||||
|
attachment_bindings = (
|
||||||
|
session.query(SegmentAttachmentBinding)
|
||||||
|
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
|
||||||
|
|
||||||
|
if attachment_bindings:
|
||||||
|
for upload_file in upload_files:
|
||||||
|
attachment_binding = attachment_binding_map.get(upload_file.id)
|
||||||
|
attachment_info = {
|
||||||
|
"id": upload_file.id,
|
||||||
|
"name": upload_file.name,
|
||||||
|
"extension": "." + upload_file.extension,
|
||||||
|
"mime_type": upload_file.mime_type,
|
||||||
|
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
|
||||||
|
"size": upload_file.size,
|
||||||
|
}
|
||||||
|
if attachment_binding:
|
||||||
|
attachment_infos.append(
|
||||||
|
{
|
||||||
|
"attachment_id": attachment_binding.attachment_id,
|
||||||
|
"attachment_info": attachment_info,
|
||||||
|
"segment_id": attachment_binding.segment_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return attachment_infos
|
||||||
|
|
|
||||||
|
|
@ -289,7 +289,8 @@ class OracleVector(BaseVector):
|
||||||
words = pseg.cut(query)
|
words = pseg.cut(query)
|
||||||
current_entity = ""
|
current_entity = ""
|
||||||
for word, pos in words:
|
for word, pos in words:
|
||||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
|
# `nr`: Person, `ns`: Location, `nt`: Organization
|
||||||
|
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:
|
||||||
current_entity += word
|
current_entity += word
|
||||||
else:
|
else:
|
||||||
if current_entity:
|
if current_entity:
|
||||||
|
|
|
||||||
|
|
@ -213,7 +213,7 @@ class VastbaseVector(BaseVector):
|
||||||
|
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
||||||
# Vastbase 支持的向量维度取值范围为 [1,16000]
|
# Vastbase supports vector dimensions in the range [1, 16,000]
|
||||||
if dimension <= 16000:
|
if dimension <= 16000:
|
||||||
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ class FirecrawlApp:
|
||||||
}
|
}
|
||||||
if params:
|
if params:
|
||||||
json_data.update(params)
|
json_data.update(params)
|
||||||
response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers)
|
response = self._post_request(self._build_url("v2/scrape"), json_data, headers)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
data = response_data["data"]
|
data = response_data["data"]
|
||||||
|
|
@ -42,7 +42,7 @@ class FirecrawlApp:
|
||||||
json_data = {"url": url}
|
json_data = {"url": url}
|
||||||
if params:
|
if params:
|
||||||
json_data.update(params)
|
json_data.update(params)
|
||||||
response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers)
|
response = self._post_request(self._build_url("v2/crawl"), json_data, headers)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
# There's also another two fields in the response: "success" (bool) and "url" (str)
|
# There's also another two fields in the response: "success" (bool) and "url" (str)
|
||||||
job_id = response.json().get("id")
|
job_id = response.json().get("id")
|
||||||
|
|
@ -58,7 +58,7 @@ class FirecrawlApp:
|
||||||
if params:
|
if params:
|
||||||
# Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
|
# Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
|
||||||
json_data.update(params)
|
json_data.update(params)
|
||||||
response = self._post_request(f"{self.base_url}/v2/map", json_data, headers)
|
response = self._post_request(self._build_url("v2/map"), json_data, headers)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return cast(dict[str, Any], response.json())
|
return cast(dict[str, Any], response.json())
|
||||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||||
|
|
@ -69,7 +69,7 @@ class FirecrawlApp:
|
||||||
|
|
||||||
def check_crawl_status(self, job_id) -> dict[str, Any]:
|
def check_crawl_status(self, job_id) -> dict[str, Any]:
|
||||||
headers = self._prepare_headers()
|
headers = self._prepare_headers()
|
||||||
response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers)
|
response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
crawl_status_response = response.json()
|
crawl_status_response = response.json()
|
||||||
if crawl_status_response.get("status") == "completed":
|
if crawl_status_response.get("status") == "completed":
|
||||||
|
|
@ -120,6 +120,10 @@ class FirecrawlApp:
|
||||||
def _prepare_headers(self) -> dict[str, Any]:
|
def _prepare_headers(self) -> dict[str, Any]:
|
||||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||||
|
|
||||||
|
def _build_url(self, path: str) -> str:
|
||||||
|
# ensure exactly one slash between base and path, regardless of user-provided base_url
|
||||||
|
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
|
||||||
|
|
||||||
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
|
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
|
||||||
for attempt in range(retries):
|
for attempt in range(retries):
|
||||||
response = httpx.post(url, headers=headers, json=data)
|
response = httpx.post(url, headers=headers, json=data)
|
||||||
|
|
@ -139,7 +143,11 @@ class FirecrawlApp:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _handle_error(self, response, action):
|
def _handle_error(self, response, action):
|
||||||
error_message = response.json().get("error", "Unknown error occurred")
|
try:
|
||||||
|
payload = response.json()
|
||||||
|
error_message = payload.get("error") or payload.get("message") or response.text or "Unknown error occurred"
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
error_message = response.text or "Unknown error occurred"
|
||||||
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
|
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
|
||||||
|
|
||||||
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
|
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||||
|
|
@ -160,7 +168,7 @@ class FirecrawlApp:
|
||||||
}
|
}
|
||||||
if params:
|
if params:
|
||||||
json_data.update(params)
|
json_data.update(params)
|
||||||
response = self._post_request(f"{self.base_url}/v2/search", json_data, headers)
|
response = self._post_request(self._build_url("v2/search"), json_data, headers)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
if not response_data.get("success"):
|
if not response_data.get("success"):
|
||||||
|
|
|
||||||
|
|
@ -48,13 +48,21 @@ class NotionExtractor(BaseExtractor):
|
||||||
if notion_access_token:
|
if notion_access_token:
|
||||||
self._notion_access_token = notion_access_token
|
self._notion_access_token = notion_access_token
|
||||||
else:
|
else:
|
||||||
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
|
try:
|
||||||
if not self._notion_access_token:
|
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
(
|
||||||
|
"Failed to get Notion access token from datasource credentials: %s, "
|
||||||
|
"falling back to environment variable NOTION_INTEGRATION_TOKEN"
|
||||||
|
),
|
||||||
|
e,
|
||||||
|
)
|
||||||
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
|
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
|
||||||
if integration_token is None:
|
if integration_token is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`."
|
"Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`."
|
||||||
)
|
) from e
|
||||||
|
|
||||||
self._notion_access_token = integration_token
|
self._notion_access_token = integration_token
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -83,6 +83,7 @@ class WordExtractor(BaseExtractor):
|
||||||
def _extract_images_from_docx(self, doc):
|
def _extract_images_from_docx(self, doc):
|
||||||
image_count = 0
|
image_count = 0
|
||||||
image_map = {}
|
image_map = {}
|
||||||
|
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||||
|
|
||||||
for r_id, rel in doc.part.rels.items():
|
for r_id, rel in doc.part.rels.items():
|
||||||
if "image" in rel.target_ref:
|
if "image" in rel.target_ref:
|
||||||
|
|
@ -121,8 +122,7 @@ class WordExtractor(BaseExtractor):
|
||||||
used_at=naive_utc_now(),
|
used_at=naive_utc_now(),
|
||||||
)
|
)
|
||||||
db.session.add(upload_file)
|
db.session.add(upload_file)
|
||||||
# Use r_id as key for external images since target_part is undefined
|
image_map[r_id] = f""
|
||||||
image_map[r_id] = f""
|
|
||||||
else:
|
else:
|
||||||
image_ext = rel.target_ref.split(".")[-1]
|
image_ext = rel.target_ref.split(".")[-1]
|
||||||
if image_ext is None:
|
if image_ext is None:
|
||||||
|
|
@ -150,10 +150,7 @@ class WordExtractor(BaseExtractor):
|
||||||
used_at=naive_utc_now(),
|
used_at=naive_utc_now(),
|
||||||
)
|
)
|
||||||
db.session.add(upload_file)
|
db.session.add(upload_file)
|
||||||
# Use target_part as key for internal images
|
image_map[rel.target_part] = f""
|
||||||
image_map[rel.target_part] = (
|
|
||||||
f""
|
|
||||||
)
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return image_map
|
return image_map
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -231,7 +231,7 @@ class BaseIndexProcessor(ABC):
|
||||||
|
|
||||||
if not filename:
|
if not filename:
|
||||||
parsed_url = urlparse(image_url)
|
parsed_url = urlparse(image_url)
|
||||||
# unquote 处理 URL 中的中文
|
# Decode percent-encoded characters in the URL path.
|
||||||
path = unquote(parsed_url.path)
|
path = unquote(parsed_url.path)
|
||||||
filename = os.path.basename(path)
|
filename = os.path.basename(path)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -151,20 +151,14 @@ class DatasetRetrieval:
|
||||||
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
||||||
planning_strategy = PlanningStrategy.ROUTER
|
planning_strategy = PlanningStrategy.ROUTER
|
||||||
available_datasets = []
|
available_datasets = []
|
||||||
for dataset_id in dataset_ids:
|
|
||||||
# get dataset from dataset id
|
|
||||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
|
||||||
dataset = db.session.scalar(dataset_stmt)
|
|
||||||
|
|
||||||
# pass if dataset is not available
|
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
|
||||||
if not dataset:
|
datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
|
||||||
|
for dataset in datasets:
|
||||||
|
if dataset.available_document_count == 0 and dataset.provider != "external":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# pass if dataset is not available
|
|
||||||
if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
|
|
||||||
continue
|
|
||||||
|
|
||||||
available_datasets.append(dataset)
|
available_datasets.append(dataset)
|
||||||
|
|
||||||
if inputs:
|
if inputs:
|
||||||
inputs = {key: str(value) for key, value in inputs.items()}
|
inputs = {key: str(value) for key, value in inputs.items()}
|
||||||
else:
|
else:
|
||||||
|
|
@ -282,26 +276,35 @@ class DatasetRetrieval:
|
||||||
)
|
)
|
||||||
context_files.append(attachment_info)
|
context_files.append(attachment_info)
|
||||||
if show_retrieve_source:
|
if show_retrieve_source:
|
||||||
|
dataset_ids = [record.segment.dataset_id for record in records]
|
||||||
|
document_ids = [record.segment.document_id for record in records]
|
||||||
|
dataset_document_stmt = select(DatasetDocument).where(
|
||||||
|
DatasetDocument.id.in_(document_ids),
|
||||||
|
DatasetDocument.enabled == True,
|
||||||
|
DatasetDocument.archived == False,
|
||||||
|
)
|
||||||
|
documents = db.session.execute(dataset_document_stmt).scalars().all() # type: ignore
|
||||||
|
dataset_stmt = select(Dataset).where(
|
||||||
|
Dataset.id.in_(dataset_ids),
|
||||||
|
)
|
||||||
|
datasets = db.session.execute(dataset_stmt).scalars().all() # type: ignore
|
||||||
|
dataset_map = {i.id: i for i in datasets}
|
||||||
|
document_map = {i.id: i for i in documents}
|
||||||
for record in records:
|
for record in records:
|
||||||
segment = record.segment
|
segment = record.segment
|
||||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
dataset_item = dataset_map.get(segment.dataset_id)
|
||||||
dataset_document_stmt = select(DatasetDocument).where(
|
document_item = document_map.get(segment.document_id)
|
||||||
DatasetDocument.id == segment.document_id,
|
if dataset_item and document_item:
|
||||||
DatasetDocument.enabled == True,
|
|
||||||
DatasetDocument.archived == False,
|
|
||||||
)
|
|
||||||
document = db.session.scalar(dataset_document_stmt)
|
|
||||||
if dataset and document:
|
|
||||||
source = RetrievalSourceMetadata(
|
source = RetrievalSourceMetadata(
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset_item.id,
|
||||||
dataset_name=dataset.name,
|
dataset_name=dataset_item.name,
|
||||||
document_id=document.id,
|
document_id=document_item.id,
|
||||||
document_name=document.name,
|
document_name=document_item.name,
|
||||||
data_source_type=document.data_source_type,
|
data_source_type=document_item.data_source_type,
|
||||||
segment_id=segment.id,
|
segment_id=segment.id,
|
||||||
retriever_from=invoke_from.to_source(),
|
retriever_from=invoke_from.to_source(),
|
||||||
score=record.score or 0.0,
|
score=record.score or 0.0,
|
||||||
doc_metadata=document.doc_metadata,
|
doc_metadata=document_item.doc_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
if invoke_from.to_source() == "dev":
|
if invoke_from.to_source() == "dev":
|
||||||
|
|
|
||||||
|
|
@ -153,11 +153,11 @@ class ToolInvokeMessage(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def transform_variable_value(cls, values):
|
def transform_variable_value(cls, values):
|
||||||
"""
|
"""
|
||||||
Only basic types and lists are allowed.
|
Only basic types, lists, and None are allowed.
|
||||||
"""
|
"""
|
||||||
value = values.get("variable_value")
|
value = values.get("variable_value")
|
||||||
if not isinstance(value, dict | list | str | int | float | bool):
|
if value is not None and not isinstance(value, dict | list | str | int | float | bool):
|
||||||
raise ValueError("Only basic types and lists are allowed.")
|
raise ValueError("Only basic types, lists, and None are allowed.")
|
||||||
|
|
||||||
# if stream is true, the value must be a string
|
# if stream is true, the value must be a string
|
||||||
if values.get("stream"):
|
if values.get("stream"):
|
||||||
|
|
|
||||||
|
|
@ -67,12 +67,16 @@ def create_trigger_provider_encrypter_for_subscription(
|
||||||
|
|
||||||
|
|
||||||
def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
|
def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
|
||||||
cache = TriggerProviderCredentialsCache(
|
TriggerProviderCredentialsCache(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
credential_id=subscription_id,
|
credential_id=subscription_id,
|
||||||
)
|
).delete()
|
||||||
cache.delete()
|
TriggerProviderPropertiesCache(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
).delete()
|
||||||
|
|
||||||
|
|
||||||
def create_trigger_provider_encrypter_for_properties(
|
def create_trigger_provider_encrypter_for_properties(
|
||||||
|
|
|
||||||
|
|
@ -247,6 +247,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||||
DATASOURCE_INFO = "datasource_info"
|
DATASOURCE_INFO = "datasource_info"
|
||||||
|
COMPLETED_REASON = "completed_reason" # completed reason for loop node
|
||||||
|
|
||||||
|
|
||||||
class WorkflowNodeExecutionStatus(StrEnum):
|
class WorkflowNodeExecutionStatus(StrEnum):
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,11 @@ class Executor:
|
||||||
node_data.authorization.config.api_key = variable_pool.convert_template(
|
node_data.authorization.config.api_key = variable_pool.convert_template(
|
||||||
node_data.authorization.config.api_key
|
node_data.authorization.config.api_key
|
||||||
).text
|
).text
|
||||||
|
# Validate that API key is not empty after template conversion
|
||||||
|
if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip():
|
||||||
|
raise AuthorizationConfigError(
|
||||||
|
"API key is required for authorization but was empty. Please provide a valid API key."
|
||||||
|
)
|
||||||
|
|
||||||
self.url = node_data.url
|
self.url = node_data.url
|
||||||
self.method = node_data.method
|
self.method = node_data.method
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from enum import StrEnum
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||||
|
|
@ -96,3 +97,8 @@ class LoopState(BaseLoopState):
|
||||||
Get current output.
|
Get current output.
|
||||||
"""
|
"""
|
||||||
return self.current_output
|
return self.current_output
|
||||||
|
|
||||||
|
|
||||||
|
class LoopCompletedReason(StrEnum):
|
||||||
|
LOOP_BREAK = "loop_break"
|
||||||
|
LOOP_COMPLETED = "loop_completed"
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ from core.workflow.node_events import (
|
||||||
)
|
)
|
||||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
|
||||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
|
@ -96,6 +96,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||||
loop_duration_map: dict[str, float] = {}
|
loop_duration_map: dict[str, float] = {}
|
||||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||||
loop_usage = LLMUsage.empty_usage()
|
loop_usage = LLMUsage.empty_usage()
|
||||||
|
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
|
||||||
|
|
||||||
# Start Loop event
|
# Start Loop event
|
||||||
yield LoopStartedEvent(
|
yield LoopStartedEvent(
|
||||||
|
|
@ -118,6 +119,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||||
loop_count = 0
|
loop_count = 0
|
||||||
|
|
||||||
for i in range(loop_count):
|
for i in range(loop_count):
|
||||||
|
# Clear stale variables from previous loop iterations to avoid streaming old values
|
||||||
|
self._clear_loop_subgraph_variables(loop_node_ids)
|
||||||
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
|
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
|
||||||
|
|
||||||
loop_start_time = naive_utc_now()
|
loop_start_time = naive_utc_now()
|
||||||
|
|
@ -177,7 +180,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||||
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
|
WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: (
|
||||||
|
LoopCompletedReason.LOOP_BREAK
|
||||||
|
if reach_break_condition
|
||||||
|
else LoopCompletedReason.LOOP_COMPLETED.value
|
||||||
|
),
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||||
},
|
},
|
||||||
|
|
@ -274,6 +281,17 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||||
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
|
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
|
||||||
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
|
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
|
||||||
|
|
||||||
|
def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None:
|
||||||
|
"""
|
||||||
|
Remove variables produced by loop sub-graph nodes from previous iterations.
|
||||||
|
|
||||||
|
Keeping stale variables causes a freshly created response coordinator in the
|
||||||
|
next iteration to fall back to outdated values when no stream chunks exist.
|
||||||
|
"""
|
||||||
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
|
for node_id in loop_node_ids:
|
||||||
|
variable_pool.remove([node_id])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
cls,
|
cls,
|
||||||
|
|
|
||||||
|
|
@ -281,7 +281,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||||
|
|
||||||
# handle invoke result
|
# handle invoke result
|
||||||
|
|
||||||
text = invoke_result.message.content or ""
|
text = invoke_result.message.get_text_content()
|
||||||
if not isinstance(text, str):
|
if not isinstance(text, str):
|
||||||
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
|
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ class AliyunOssStorage(BaseStorage):
|
||||||
self.bucket_name,
|
self.bucket_name,
|
||||||
connect_timeout=30,
|
connect_timeout=30,
|
||||||
region=region,
|
region=region,
|
||||||
|
cloudbox_id=dify_config.ALIYUN_CLOUDBOX_ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
def save(self, filename, data):
|
def save(self, filename, data):
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ class HuaweiObsStorage(BaseStorage):
|
||||||
access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
|
access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY,
|
||||||
secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
|
secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY,
|
||||||
server=dify_config.HUAWEI_OBS_SERVER,
|
server=dify_config.HUAWEI_OBS_SERVER,
|
||||||
|
path_style=dify_config.HUAWEI_OBS_PATH_STYLE,
|
||||||
)
|
)
|
||||||
|
|
||||||
def save(self, filename, data):
|
def save(self, filename, data):
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,7 @@ dependencies = [
|
||||||
"pydantic-extra-types~=2.10.3",
|
"pydantic-extra-types~=2.10.3",
|
||||||
"pydantic-settings~=2.11.0",
|
"pydantic-settings~=2.11.0",
|
||||||
"pyjwt~=2.10.1",
|
"pyjwt~=2.10.1",
|
||||||
"pypdfium2==4.30.0",
|
"pypdfium2==5.2.0",
|
||||||
"python-docx~=1.1.0",
|
"python-docx~=1.1.0",
|
||||||
"python-dotenv==1.0.1",
|
"python-dotenv==1.0.1",
|
||||||
"pyyaml~=6.0.1",
|
"pyyaml~=6.0.1",
|
||||||
|
|
|
||||||
|
|
@ -155,6 +155,7 @@ class AppDslService:
|
||||||
parsed_url.scheme == "https"
|
parsed_url.scheme == "https"
|
||||||
and parsed_url.netloc == "github.com"
|
and parsed_url.netloc == "github.com"
|
||||||
and parsed_url.path.endswith((".yml", ".yaml"))
|
and parsed_url.path.endswith((".yml", ".yaml"))
|
||||||
|
and "/blob/" in parsed_url.path
|
||||||
):
|
):
|
||||||
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
|
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
|
||||||
yaml_url = yaml_url.replace("/blob/", "/")
|
yaml_url = yaml_url.replace("/blob/", "/")
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
|
||||||
"limit": 1,
|
"limit": 1,
|
||||||
"scrapeOptions": {"onlyMainContent": True},
|
"scrapeOptions": {"onlyMainContent": True},
|
||||||
}
|
}
|
||||||
response = self._post_request(f"{self.base_url}/v1/crawl", options, headers)
|
response = self._post_request(self._build_url("v1/crawl"), options, headers)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
|
|
@ -35,15 +35,17 @@ class FirecrawlAuth(ApiKeyAuthBase):
|
||||||
def _prepare_headers(self):
|
def _prepare_headers(self):
|
||||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||||
|
|
||||||
|
def _build_url(self, path: str) -> str:
|
||||||
|
# ensure exactly one slash between base and path, regardless of user-provided base_url
|
||||||
|
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
|
||||||
|
|
||||||
def _post_request(self, url, data, headers):
|
def _post_request(self, url, data, headers):
|
||||||
return httpx.post(url, headers=headers, json=data)
|
return httpx.post(url, headers=headers, json=data)
|
||||||
|
|
||||||
def _handle_error(self, response):
|
def _handle_error(self, response):
|
||||||
if response.status_code in {402, 409, 500}:
|
try:
|
||||||
error_message = response.json().get("error", "Unknown error occurred")
|
payload = response.json()
|
||||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
except json.JSONDecodeError:
|
||||||
else:
|
payload = {}
|
||||||
if response.text:
|
error_message = payload.get("error") or payload.get("message") or (response.text or "Unknown error occurred")
|
||||||
error_message = json.loads(response.text).get("error", "Unknown error occurred")
|
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
|
||||||
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,9 @@ from typing import Any, Union
|
||||||
from sqlalchemy import asc, desc, func, or_, select
|
from sqlalchemy import asc, desc, func, or_, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.db.session_factory import session_factory
|
||||||
from core.llm_generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.variables.types import SegmentType
|
from core.variables.types import SegmentType
|
||||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||||
|
|
@ -202,6 +204,7 @@ class ConversationService:
|
||||||
user: Union[Account, EndUser] | None,
|
user: Union[Account, EndUser] | None,
|
||||||
limit: int,
|
limit: int,
|
||||||
last_id: str | None,
|
last_id: str | None,
|
||||||
|
variable_name: str | None = None,
|
||||||
) -> InfiniteScrollPagination:
|
) -> InfiniteScrollPagination:
|
||||||
conversation = cls.get_conversation(app_model, conversation_id, user)
|
conversation = cls.get_conversation(app_model, conversation_id, user)
|
||||||
|
|
||||||
|
|
@ -212,7 +215,25 @@ class ConversationService:
|
||||||
.order_by(ConversationVariable.created_at)
|
.order_by(ConversationVariable.created_at)
|
||||||
)
|
)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
# Apply variable_name filter if provided
|
||||||
|
if variable_name:
|
||||||
|
# Filter using JSON extraction to match variable names case-insensitively
|
||||||
|
escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||||
|
# Filter using JSON extraction to match variable names case-insensitively
|
||||||
|
if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
|
||||||
|
stmt = stmt.where(
|
||||||
|
func.json_extract(ConversationVariable.data, "$.name").ilike(
|
||||||
|
f"%{escaped_variable_name}%", escape="\\"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif dify_config.DB_TYPE == "postgresql":
|
||||||
|
stmt = stmt.where(
|
||||||
|
func.json_extract_path_text(ConversationVariable.data, "name").ilike(
|
||||||
|
f"%{escaped_variable_name}%", escape="\\"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with session_factory.create_session() as session:
|
||||||
if last_id:
|
if last_id:
|
||||||
last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id))
|
last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id))
|
||||||
if not last_variable:
|
if not last_variable:
|
||||||
|
|
@ -279,7 +300,7 @@ class ConversationService:
|
||||||
.where(ConversationVariable.id == variable_id)
|
.where(ConversationVariable.id == variable_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with session_factory.create_session() as session:
|
||||||
existing_variable = session.scalar(stmt)
|
existing_variable = session.scalar(stmt)
|
||||||
if not existing_variable:
|
if not existing_variable:
|
||||||
raise ConversationVariableNotExistsError()
|
raise ConversationVariableNotExistsError()
|
||||||
|
|
|
||||||
|
|
@ -105,3 +105,49 @@ class PluginParameterService:
|
||||||
)
|
)
|
||||||
.options
|
.options
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_dynamic_select_options_with_credentials(
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
plugin_id: str,
|
||||||
|
provider: str,
|
||||||
|
action: str,
|
||||||
|
parameter: str,
|
||||||
|
credential_id: str,
|
||||||
|
credentials: Mapping[str, Any],
|
||||||
|
) -> Sequence[PluginParameterOption]:
|
||||||
|
"""
|
||||||
|
Get dynamic select options using provided credentials directly.
|
||||||
|
Used for edit mode when credentials have been modified but not yet saved.
|
||||||
|
|
||||||
|
Security: credential_id is validated against tenant_id to ensure
|
||||||
|
users can only access their own credentials.
|
||||||
|
"""
|
||||||
|
from constants import HIDDEN_VALUE
|
||||||
|
|
||||||
|
# Get original subscription to replace hidden values (with tenant_id check for security)
|
||||||
|
original_subscription = TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
|
||||||
|
if not original_subscription:
|
||||||
|
raise ValueError(f"Subscription {credential_id} not found")
|
||||||
|
|
||||||
|
# Replace [__HIDDEN__] with original values
|
||||||
|
resolved_credentials: dict[str, Any] = {
|
||||||
|
key: (original_subscription.credentials.get(key) if value == HIDDEN_VALUE else value)
|
||||||
|
for key, value in credentials.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
DynamicSelectClient()
|
||||||
|
.fetch_dynamic_select_options(
|
||||||
|
tenant_id,
|
||||||
|
user_id,
|
||||||
|
plugin_id,
|
||||||
|
provider,
|
||||||
|
action,
|
||||||
|
resolved_credentials,
|
||||||
|
CredentialType.API_KEY.value,
|
||||||
|
parameter,
|
||||||
|
)
|
||||||
|
.options
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
|
||||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
|
||||||
from core.mcp.auth.auth_flow import auth
|
from core.mcp.auth.auth_flow import auth
|
||||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||||
from core.mcp.error import MCPAuthError, MCPError
|
from core.mcp.error import MCPAuthError, MCPError
|
||||||
|
|
@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel):
|
||||||
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderUrlValidationData(BaseModel):
|
||||||
|
"""Data required for URL validation, extracted from database to perform network operations outside of session"""
|
||||||
|
|
||||||
|
current_server_url_hash: str
|
||||||
|
headers: dict[str, str]
|
||||||
|
timeout: float | None
|
||||||
|
sse_read_timeout: float | None
|
||||||
|
|
||||||
|
|
||||||
class MCPToolManageService:
|
class MCPToolManageService:
|
||||||
"""Service class for managing MCP tools and providers."""
|
"""Service class for managing MCP tools and providers."""
|
||||||
|
|
||||||
|
|
@ -166,9 +174,6 @@ class MCPToolManageService:
|
||||||
self._session.add(mcp_tool)
|
self._session.add(mcp_tool)
|
||||||
self._session.flush()
|
self._session.flush()
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
|
|
||||||
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||||
return mcp_providers
|
return mcp_providers
|
||||||
|
|
||||||
|
|
@ -192,7 +197,7 @@ class MCPToolManageService:
|
||||||
Update an MCP provider.
|
Update an MCP provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
validation_result: Pre-validation result from validate_server_url_change.
|
validation_result: Pre-validation result from validate_server_url_standalone.
|
||||||
If provided and contains reconnect_result, it will be used
|
If provided and contains reconnect_result, it will be used
|
||||||
instead of performing network operations.
|
instead of performing network operations.
|
||||||
"""
|
"""
|
||||||
|
|
@ -251,8 +256,6 @@ class MCPToolManageService:
|
||||||
# Flush changes to database
|
# Flush changes to database
|
||||||
self._session.flush()
|
self._session.flush()
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||||
|
|
||||||
|
|
@ -261,9 +264,6 @@ class MCPToolManageService:
|
||||||
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
self._session.delete(mcp_tool)
|
self._session.delete(mcp_tool)
|
||||||
|
|
||||||
# Invalidate tool providers cache
|
|
||||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
|
||||||
|
|
||||||
def list_providers(
|
def list_providers(
|
||||||
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
|
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
|
||||||
) -> list[ToolProviderApiEntity]:
|
) -> list[ToolProviderApiEntity]:
|
||||||
|
|
@ -546,30 +546,39 @@ class MCPToolManageService:
|
||||||
)
|
)
|
||||||
return self.execute_auth_actions(auth_result)
|
return self.execute_auth_actions(auth_result)
|
||||||
|
|
||||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
|
def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData:
|
||||||
"""Attempt to reconnect to MCP provider with new server URL."""
|
"""
|
||||||
|
Get provider data required for URL validation.
|
||||||
|
This method performs database read and should be called within a session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ProviderUrlValidationData: Data needed for standalone URL validation
|
||||||
|
"""
|
||||||
|
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||||
provider_entity = provider.to_entity()
|
provider_entity = provider.to_entity()
|
||||||
headers = provider_entity.headers
|
return ProviderUrlValidationData(
|
||||||
|
current_server_url_hash=provider.server_url_hash,
|
||||||
|
headers=provider_entity.headers,
|
||||||
|
timeout=provider_entity.timeout,
|
||||||
|
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
@staticmethod
|
||||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
def validate_server_url_standalone(
|
||||||
return ReconnectResult(
|
*,
|
||||||
authed=True,
|
tenant_id: str,
|
||||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
new_server_url: str,
|
||||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
validation_data: ProviderUrlValidationData,
|
||||||
)
|
|
||||||
except MCPAuthError:
|
|
||||||
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
|
||||||
except MCPError as e:
|
|
||||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
|
||||||
|
|
||||||
def validate_server_url_change(
|
|
||||||
self, *, tenant_id: str, provider_id: str, new_server_url: str
|
|
||||||
) -> ServerUrlValidationResult:
|
) -> ServerUrlValidationResult:
|
||||||
"""
|
"""
|
||||||
Validate server URL change by attempting to connect to the new server.
|
Validate server URL change by attempting to connect to the new server.
|
||||||
This method should be called BEFORE update_provider to perform network operations
|
This method performs network operations and MUST be called OUTSIDE of any database session
|
||||||
outside of the database transaction.
|
to avoid holding locks during network I/O.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant ID for encryption
|
||||||
|
new_server_url: The new server URL to validate
|
||||||
|
validation_data: Provider data obtained from get_provider_for_url_validation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ServerUrlValidationResult: Validation result with connection status and tools if successful
|
ServerUrlValidationResult: Validation result with connection status and tools if successful
|
||||||
|
|
@ -579,25 +588,30 @@ class MCPToolManageService:
|
||||||
return ServerUrlValidationResult(needs_validation=False)
|
return ServerUrlValidationResult(needs_validation=False)
|
||||||
|
|
||||||
# Validate URL format
|
# Validate URL format
|
||||||
if not self._is_valid_url(new_server_url):
|
parsed = urlparse(new_server_url)
|
||||||
|
if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]:
|
||||||
raise ValueError("Server URL is not valid.")
|
raise ValueError("Server URL is not valid.")
|
||||||
|
|
||||||
# Always encrypt and hash the URL
|
# Always encrypt and hash the URL
|
||||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
|
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
|
||||||
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
|
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
|
||||||
|
|
||||||
# Get current provider
|
|
||||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
|
||||||
|
|
||||||
# Check if URL is actually different
|
# Check if URL is actually different
|
||||||
if new_server_url_hash == provider.server_url_hash:
|
if new_server_url_hash == validation_data.current_server_url_hash:
|
||||||
# URL hasn't changed, but still return the encrypted data
|
# URL hasn't changed, but still return the encrypted data
|
||||||
return ServerUrlValidationResult(
|
return ServerUrlValidationResult(
|
||||||
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
|
needs_validation=False,
|
||||||
|
encrypted_server_url=encrypted_server_url,
|
||||||
|
server_url_hash=new_server_url_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Perform validation by attempting to connect
|
# Perform network validation - this is the expensive operation that should be outside session
|
||||||
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
|
reconnect_result = MCPToolManageService._reconnect_with_url(
|
||||||
|
server_url=new_server_url,
|
||||||
|
headers=validation_data.headers,
|
||||||
|
timeout=validation_data.timeout,
|
||||||
|
sse_read_timeout=validation_data.sse_read_timeout,
|
||||||
|
)
|
||||||
return ServerUrlValidationResult(
|
return ServerUrlValidationResult(
|
||||||
needs_validation=True,
|
needs_validation=True,
|
||||||
validation_passed=True,
|
validation_passed=True,
|
||||||
|
|
@ -606,6 +620,38 @@ class MCPToolManageService:
|
||||||
server_url_hash=new_server_url_hash,
|
server_url_hash=new_server_url_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reconnect_with_url(
|
||||||
|
*,
|
||||||
|
server_url: str,
|
||||||
|
headers: dict[str, str],
|
||||||
|
timeout: float | None,
|
||||||
|
sse_read_timeout: float | None,
|
||||||
|
) -> ReconnectResult:
|
||||||
|
"""
|
||||||
|
Attempt to connect to MCP server with given URL.
|
||||||
|
This is a static method that performs network I/O without database access.
|
||||||
|
"""
|
||||||
|
from core.mcp.mcp_client import MCPClient
|
||||||
|
|
||||||
|
try:
|
||||||
|
with MCPClient(
|
||||||
|
server_url=server_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=timeout,
|
||||||
|
sse_read_timeout=sse_read_timeout,
|
||||||
|
) as mcp_client:
|
||||||
|
tools = mcp_client.list_tools()
|
||||||
|
return ReconnectResult(
|
||||||
|
authed=True,
|
||||||
|
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||||
|
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||||
|
)
|
||||||
|
except MCPAuthError:
|
||||||
|
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
||||||
|
except MCPError as e:
|
||||||
|
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||||
|
|
||||||
def _build_tool_provider_response(
|
def _build_tool_provider_response(
|
||||||
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
||||||
) -> ToolProviderApiEntity:
|
) -> ToolProviderApiEntity:
|
||||||
|
|
|
||||||
|
|
@ -94,16 +94,23 @@ class TriggerProviderService:
|
||||||
|
|
||||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
for subscription in subscriptions:
|
for subscription in subscriptions:
|
||||||
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
controller=provider_controller,
|
controller=provider_controller,
|
||||||
subscription=subscription,
|
subscription=subscription,
|
||||||
)
|
)
|
||||||
subscription.credentials = dict(
|
subscription.credentials = dict(
|
||||||
encrypter.mask_credentials(dict(encrypter.decrypt(subscription.credentials)))
|
credential_encrypter.mask_credentials(dict(credential_encrypter.decrypt(subscription.credentials)))
|
||||||
)
|
)
|
||||||
subscription.properties = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.properties))))
|
properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
|
||||||
subscription.parameters = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.parameters))))
|
tenant_id=tenant_id,
|
||||||
|
controller=provider_controller,
|
||||||
|
subscription=subscription,
|
||||||
|
)
|
||||||
|
subscription.properties = dict(
|
||||||
|
properties_encrypter.mask_credentials(dict(properties_encrypter.decrypt(subscription.properties)))
|
||||||
|
)
|
||||||
|
subscription.parameters = dict(subscription.parameters)
|
||||||
count = workflows_in_use_map.get(subscription.id)
|
count = workflows_in_use_map.get(subscription.id)
|
||||||
subscription.workflows_in_use = count if count is not None else 0
|
subscription.workflows_in_use = count if count is not None else 0
|
||||||
|
|
||||||
|
|
@ -209,6 +216,101 @@ class TriggerProviderService:
|
||||||
logger.exception("Failed to add trigger provider")
|
logger.exception("Failed to add trigger provider")
|
||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_trigger_subscription(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
subscription_id: str,
|
||||||
|
name: str | None = None,
|
||||||
|
properties: Mapping[str, Any] | None = None,
|
||||||
|
parameters: Mapping[str, Any] | None = None,
|
||||||
|
credentials: Mapping[str, Any] | None = None,
|
||||||
|
credential_expires_at: int | None = None,
|
||||||
|
expires_at: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update an existing trigger subscription.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param subscription_id: Subscription instance ID
|
||||||
|
:param name: Optional new name for this subscription
|
||||||
|
:param properties: Optional new properties
|
||||||
|
:param parameters: Optional new parameters
|
||||||
|
:param credentials: Optional new credentials
|
||||||
|
:param credential_expires_at: Optional new credential expiration timestamp
|
||||||
|
:param expires_at: Optional new expiration timestamp
|
||||||
|
:return: Success response with updated subscription info
|
||||||
|
"""
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
|
# Use distributed lock to prevent race conditions on the same subscription
|
||||||
|
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
|
||||||
|
with redis_client.lock(lock_key, timeout=20):
|
||||||
|
subscription: TriggerSubscription | None = (
|
||||||
|
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||||
|
)
|
||||||
|
if not subscription:
|
||||||
|
raise ValueError(f"Trigger subscription {subscription_id} not found")
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID(subscription.provider_id)
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
|
||||||
|
# Check for name uniqueness if name is being updated
|
||||||
|
if name is not None and name != subscription.name:
|
||||||
|
existing = (
|
||||||
|
session.query(TriggerSubscription)
|
||||||
|
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise ValueError(f"Subscription name '{name}' already exists for this provider")
|
||||||
|
subscription.name = name
|
||||||
|
|
||||||
|
# Update properties if provided
|
||||||
|
if properties is not None:
|
||||||
|
properties_encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=provider_controller.get_properties_schema(),
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
# Handle hidden values - preserve original encrypted values
|
||||||
|
original_properties = properties_encrypter.decrypt(subscription.properties)
|
||||||
|
new_properties: dict[str, Any] = {
|
||||||
|
key: value if value != HIDDEN_VALUE else original_properties.get(key, UNKNOWN_VALUE)
|
||||||
|
for key, value in properties.items()
|
||||||
|
}
|
||||||
|
subscription.properties = dict(properties_encrypter.encrypt(new_properties))
|
||||||
|
|
||||||
|
# Update parameters if provided
|
||||||
|
if parameters is not None:
|
||||||
|
subscription.parameters = dict(parameters)
|
||||||
|
|
||||||
|
# Update credentials if provided
|
||||||
|
if credentials is not None:
|
||||||
|
credential_type = CredentialType.of(subscription.credential_type)
|
||||||
|
credential_encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=provider_controller.get_credential_schema_config(credential_type),
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
subscription.credentials = dict(credential_encrypter.encrypt(dict(credentials)))
|
||||||
|
|
||||||
|
# Update credential expiration timestamp if provided
|
||||||
|
if credential_expires_at is not None:
|
||||||
|
subscription.credential_expires_at = credential_expires_at
|
||||||
|
|
||||||
|
# Update expiration timestamp if provided
|
||||||
|
if expires_at is not None:
|
||||||
|
subscription.expires_at = expires_at
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Clear subscription cache
|
||||||
|
delete_cache_for_subscription(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=subscription.provider_id,
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_subscription_by_id(cls, tenant_id: str, subscription_id: str | None = None) -> TriggerSubscription | None:
|
def get_subscription_by_id(cls, tenant_id: str, subscription_id: str | None = None) -> TriggerSubscription | None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -257,17 +359,18 @@ class TriggerProviderService:
|
||||||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||||
|
|
||||||
credential_type: CredentialType = CredentialType.of(subscription.credential_type)
|
credential_type: CredentialType = CredentialType.of(subscription.credential_type)
|
||||||
|
provider_id = TriggerProviderID(subscription.provider_id)
|
||||||
|
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
||||||
|
tenant_id=tenant_id, provider_id=provider_id
|
||||||
|
)
|
||||||
|
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
controller=provider_controller,
|
||||||
|
subscription=subscription,
|
||||||
|
)
|
||||||
|
|
||||||
is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
|
is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
|
||||||
if is_auto_created:
|
if is_auto_created:
|
||||||
provider_id = TriggerProviderID(subscription.provider_id)
|
|
||||||
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
|
||||||
tenant_id=tenant_id, provider_id=provider_id
|
|
||||||
)
|
|
||||||
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
controller=provider_controller,
|
|
||||||
subscription=subscription,
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
TriggerManager.unsubscribe_trigger(
|
TriggerManager.unsubscribe_trigger(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
|
@ -280,8 +383,8 @@ class TriggerProviderService:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error unsubscribing trigger", exc_info=e)
|
logger.exception("Error unsubscribing trigger", exc_info=e)
|
||||||
|
|
||||||
# Clear cache
|
|
||||||
session.delete(subscription)
|
session.delete(subscription)
|
||||||
|
# Clear cache
|
||||||
delete_cache_for_subscription(
|
delete_cache_for_subscription(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider_id=subscription.provider_id,
|
provider_id=subscription.provider_id,
|
||||||
|
|
@ -688,3 +791,125 @@ class TriggerProviderService:
|
||||||
)
|
)
|
||||||
subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
|
subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
|
||||||
return subscription
|
return subscription
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def verify_subscription_credentials(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
subscription_id: str,
|
||||||
|
credentials: Mapping[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Verify credentials for an existing subscription without updating it.
|
||||||
|
|
||||||
|
This is used in edit mode to validate new credentials before rebuild.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param user_id: User ID
|
||||||
|
:param provider_id: Provider identifier
|
||||||
|
:param subscription_id: Subscription ID
|
||||||
|
:param credentials: New credentials to verify
|
||||||
|
:return: dict with 'verified' boolean
|
||||||
|
"""
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
if not provider_controller:
|
||||||
|
raise ValueError(f"Provider {provider_id} not found")
|
||||||
|
|
||||||
|
subscription = cls.get_subscription_by_id(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
)
|
||||||
|
if not subscription:
|
||||||
|
raise ValueError(f"Subscription {subscription_id} not found")
|
||||||
|
|
||||||
|
credential_type = CredentialType.of(subscription.credential_type)
|
||||||
|
|
||||||
|
# For API Key, validate the new credentials
|
||||||
|
if credential_type == CredentialType.API_KEY:
|
||||||
|
new_credentials: dict[str, Any] = {
|
||||||
|
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
|
||||||
|
for key, value in credentials.items()
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
provider_controller.validate_credentials(user_id, credentials=new_credentials)
|
||||||
|
return {"verified": True}
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid credentials: {e}") from e
|
||||||
|
|
||||||
|
return {"verified": True}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def rebuild_trigger_subscription(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
subscription_id: str,
|
||||||
|
credentials: Mapping[str, Any],
|
||||||
|
parameters: Mapping[str, Any],
|
||||||
|
name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create a subscription builder for rebuilding an existing subscription.
|
||||||
|
|
||||||
|
This method creates a builder pre-filled with data from the rebuild request,
|
||||||
|
keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param name: Name for the subscription
|
||||||
|
:param subscription_id: Subscription ID
|
||||||
|
:param provider_id: Provider identifier
|
||||||
|
:param credentials: Credentials for the subscription
|
||||||
|
:param parameters: Parameters for the subscription
|
||||||
|
:return: SubscriptionBuilderApiEntity
|
||||||
|
"""
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
if not provider_controller:
|
||||||
|
raise ValueError(f"Provider {provider_id} not found")
|
||||||
|
|
||||||
|
subscription = TriggerProviderService.get_subscription_by_id(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
)
|
||||||
|
if not subscription:
|
||||||
|
raise ValueError(f"Subscription {subscription_id} not found")
|
||||||
|
|
||||||
|
credential_type = CredentialType.of(subscription.credential_type)
|
||||||
|
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
|
||||||
|
raise ValueError("Credential type not supported for rebuild")
|
||||||
|
|
||||||
|
# TODO: Trying to invoke update api of the plugin trigger provider
|
||||||
|
|
||||||
|
# FALLBACK: If the update api is not implemented, delete the previous subscription and create a new one
|
||||||
|
|
||||||
|
# Delete the previous subscription
|
||||||
|
user_id = subscription.user_id
|
||||||
|
TriggerManager.unsubscribe_trigger(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription=subscription.to_entity(),
|
||||||
|
credentials=subscription.credentials,
|
||||||
|
credential_type=credential_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new subscription with the same subscription_id and endpoint_id
|
||||||
|
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
|
||||||
|
parameters=parameters,
|
||||||
|
credentials=credentials,
|
||||||
|
credential_type=credential_type,
|
||||||
|
)
|
||||||
|
TriggerProviderService.update_trigger_subscription(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
name=name,
|
||||||
|
parameters=parameters,
|
||||||
|
credentials=credentials,
|
||||||
|
properties=new_subscription.properties,
|
||||||
|
expires_at=new_subscription.expires_at,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -453,11 +453,12 @@ class TriggerSubscriptionBuilderService:
|
||||||
if not subscription_builder:
|
if not subscription_builder:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# response to validation endpoint
|
|
||||||
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
|
||||||
tenant_id=subscription_builder.tenant_id, provider_id=TriggerProviderID(subscription_builder.provider_id)
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
|
# response to validation endpoint
|
||||||
|
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
||||||
|
tenant_id=subscription_builder.tenant_id,
|
||||||
|
provider_id=TriggerProviderID(subscription_builder.provider_id),
|
||||||
|
)
|
||||||
dispatch_response: TriggerDispatchResponse = controller.dispatch(
|
dispatch_response: TriggerDispatchResponse = controller.dispatch(
|
||||||
request=request,
|
request=request,
|
||||||
subscription=subscription_builder.to_subscription(),
|
subscription=subscription_builder.to_subscription(),
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import sqlalchemy as sa
|
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
|
@ -12,7 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models.dataset import Dataset, Document, DocumentSegment
|
from models.dataset import Dataset, Document, DocumentSegment
|
||||||
from models.source import DataSourceOauthBinding
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -48,27 +47,36 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||||
page_id = data_source_info["notion_page_id"]
|
page_id = data_source_info["notion_page_id"]
|
||||||
page_type = data_source_info["type"]
|
page_type = data_source_info["type"]
|
||||||
page_edited_time = data_source_info["last_edited_time"]
|
page_edited_time = data_source_info["last_edited_time"]
|
||||||
|
credential_id = data_source_info.get("credential_id")
|
||||||
|
|
||||||
data_source_binding = (
|
# Get credentials from datasource provider
|
||||||
db.session.query(DataSourceOauthBinding)
|
datasource_provider_service = DatasourceProviderService()
|
||||||
.where(
|
credential = datasource_provider_service.get_datasource_credentials(
|
||||||
sa.and_(
|
tenant_id=document.tenant_id,
|
||||||
DataSourceOauthBinding.tenant_id == document.tenant_id,
|
credential_id=credential_id,
|
||||||
DataSourceOauthBinding.provider == "notion",
|
provider="notion_datasource",
|
||||||
DataSourceOauthBinding.disabled == False,
|
plugin_id="langgenius/notion_datasource",
|
||||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
if not data_source_binding:
|
|
||||||
raise ValueError("Data source binding not found.")
|
if not credential:
|
||||||
|
logger.error(
|
||||||
|
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
|
||||||
|
document_id,
|
||||||
|
document.tenant_id,
|
||||||
|
credential_id,
|
||||||
|
)
|
||||||
|
document.indexing_status = "error"
|
||||||
|
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
|
||||||
|
document.stopped_at = naive_utc_now()
|
||||||
|
db.session.commit()
|
||||||
|
db.session.close()
|
||||||
|
return
|
||||||
|
|
||||||
loader = NotionExtractor(
|
loader = NotionExtractor(
|
||||||
notion_workspace_id=workspace_id,
|
notion_workspace_id=workspace_id,
|
||||||
notion_obj_id=page_id,
|
notion_obj_id=page_id,
|
||||||
notion_page_type=page_type,
|
notion_page_type=page_type,
|
||||||
notion_access_token=data_source_binding.access_token,
|
notion_access_token=credential.get("integration_secret"),
|
||||||
tenant_id=document.tenant_id,
|
tenant_id=document.tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import pytest
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
|
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||||
|
|
@ -169,13 +170,14 @@ def test_custom_authorization_header(setup_http_mock):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||||
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
|
def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
|
||||||
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
|
"""Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
|
||||||
from core.workflow.nodes.http_request.entities import (
|
from core.workflow.nodes.http_request.entities import (
|
||||||
HttpRequestNodeAuthorization,
|
HttpRequestNodeAuthorization,
|
||||||
HttpRequestNodeData,
|
HttpRequestNodeData,
|
||||||
HttpRequestNodeTimeout,
|
HttpRequestNodeTimeout,
|
||||||
)
|
)
|
||||||
|
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
|
||||||
from core.workflow.nodes.http_request.executor import Executor
|
from core.workflow.nodes.http_request.executor import Executor
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
from core.workflow.system_variable import SystemVariable
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
|
@ -208,16 +210,13 @@ def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
|
||||||
ssl_verify=True,
|
ssl_verify=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create executor
|
# Create executor should raise AuthorizationConfigError
|
||||||
executor = Executor(
|
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||||
node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool
|
Executor(
|
||||||
)
|
node_data=node_data,
|
||||||
|
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10),
|
||||||
# Get assembled headers
|
variable_pool=variable_pool,
|
||||||
headers = executor._assembling_headers()
|
)
|
||||||
|
|
||||||
# When api_key is empty, the custom header should NOT be set
|
|
||||||
assert "X-Custom-Auth" not in headers
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||||
|
|
@ -305,9 +304,10 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
|
||||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||||
def test_custom_authorization_with_empty_api_key(setup_http_mock):
|
def test_custom_authorization_with_empty_api_key(setup_http_mock):
|
||||||
"""
|
"""
|
||||||
Test that custom authorization doesn't set header when api_key is empty.
|
Test that custom authorization raises error when api_key is empty.
|
||||||
This test verifies the fix for issue #23554.
|
This test verifies the fix for issue #21830.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
node = init_http_node(
|
node = init_http_node(
|
||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
|
|
@ -333,11 +333,10 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock):
|
||||||
)
|
)
|
||||||
|
|
||||||
result = node._run()
|
result = node._run()
|
||||||
assert result.process_data is not None
|
# Should fail with AuthorizationConfigError
|
||||||
data = result.process_data.get("request", "")
|
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||||
|
assert "API key is required" in result.error
|
||||||
# Custom header should NOT be set when api_key is empty
|
assert result.error_type == "AuthorizationConfigError"
|
||||||
assert "X-Custom-Auth:" not in data
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||||
|
|
|
||||||
|
|
@ -1308,18 +1308,17 @@ class TestMCPToolManageService:
|
||||||
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
|
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||||
# Setup mock client
|
# Setup mock client
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.return_value = mock_tools
|
mock_client_instance.list_tools.return_value = mock_tools
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
from extensions.ext_database import db
|
result = MCPToolManageService._reconnect_with_url(
|
||||||
|
|
||||||
service = MCPToolManageService(db.session())
|
|
||||||
result = service._reconnect_provider(
|
|
||||||
server_url="https://example.com/mcp",
|
server_url="https://example.com/mcp",
|
||||||
provider=mcp_provider,
|
headers={"X-Test": "1"},
|
||||||
|
timeout=mcp_provider.timeout,
|
||||||
|
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
|
|
@ -1337,8 +1336,12 @@ class TestMCPToolManageService:
|
||||||
assert tools_data[1]["name"] == "test_tool_2"
|
assert tools_data[1]["name"] == "test_tool_2"
|
||||||
|
|
||||||
# Verify mock interactions
|
# Verify mock interactions
|
||||||
provider_entity = mcp_provider.to_entity()
|
mock_mcp_client.assert_called_once_with(
|
||||||
mock_mcp_client.assert_called_once()
|
server_url="https://example.com/mcp",
|
||||||
|
headers={"X-Test": "1"},
|
||||||
|
timeout=mcp_provider.timeout,
|
||||||
|
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
"""
|
"""
|
||||||
|
|
@ -1361,19 +1364,18 @@ class TestMCPToolManageService:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock MCPClient to raise authentication error
|
# Mock MCPClient to raise authentication error
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||||
from core.mcp.error import MCPAuthError
|
from core.mcp.error import MCPAuthError
|
||||||
|
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
|
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
|
||||||
|
|
||||||
# Act: Execute the method under test
|
# Act: Execute the method under test
|
||||||
from extensions.ext_database import db
|
result = MCPToolManageService._reconnect_with_url(
|
||||||
|
|
||||||
service = MCPToolManageService(db.session())
|
|
||||||
result = service._reconnect_provider(
|
|
||||||
server_url="https://example.com/mcp",
|
server_url="https://example.com/mcp",
|
||||||
provider=mcp_provider,
|
headers={},
|
||||||
|
timeout=mcp_provider.timeout,
|
||||||
|
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert: Verify the expected outcomes
|
# Assert: Verify the expected outcomes
|
||||||
|
|
@ -1404,18 +1406,17 @@ class TestMCPToolManageService:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock MCPClient to raise connection error
|
# Mock MCPClient to raise connection error
|
||||||
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
|
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
|
||||||
from core.mcp.error import MCPError
|
from core.mcp.error import MCPError
|
||||||
|
|
||||||
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
|
||||||
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
|
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
|
||||||
|
|
||||||
# Act & Assert: Verify proper error handling
|
# Act & Assert: Verify proper error handling
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
service = MCPToolManageService(db.session())
|
|
||||||
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
|
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
|
||||||
service._reconnect_provider(
|
MCPToolManageService._reconnect_with_url(
|
||||||
server_url="https://example.com/mcp",
|
server_url="https://example.com/mcp",
|
||||||
provider=mcp_provider,
|
headers={"X-Test": "1"},
|
||||||
|
timeout=mcp_provider.timeout,
|
||||||
|
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
from flask import Response
|
||||||
|
|
||||||
|
from controllers.common.file_response import enforce_download_for_html, is_html_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileResponseHelpers:
|
||||||
|
def test_is_html_content_detects_mime_type(self):
|
||||||
|
mime_type = "text/html; charset=UTF-8"
|
||||||
|
|
||||||
|
result = is_html_content(mime_type, filename="file.txt", extension="txt")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_is_html_content_detects_extension(self):
|
||||||
|
result = is_html_content("text/plain", filename="report.html", extension=None)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_enforce_download_for_html_sets_headers(self):
|
||||||
|
response = Response("payload", mimetype="text/html")
|
||||||
|
|
||||||
|
updated = enforce_download_for_html(
|
||||||
|
response,
|
||||||
|
mime_type="text/html",
|
||||||
|
filename="unsafe.html",
|
||||||
|
extension="html",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert updated is True
|
||||||
|
assert "attachment" in response.headers["Content-Disposition"]
|
||||||
|
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||||
|
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||||
|
|
||||||
|
def test_enforce_download_for_html_no_change_for_non_html(self):
|
||||||
|
response = Response("payload", mimetype="text/plain")
|
||||||
|
|
||||||
|
updated = enforce_download_for_html(
|
||||||
|
response,
|
||||||
|
mime_type="text/plain",
|
||||||
|
filename="notes.txt",
|
||||||
|
extension="txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert updated is False
|
||||||
|
assert "Content-Disposition" not in response.headers
|
||||||
|
assert "X-Content-Type-Options" not in response.headers
|
||||||
|
|
@ -163,34 +163,17 @@ class TestActivateApi:
|
||||||
"account": mock_account,
|
"account": mock_account,
|
||||||
}
|
}
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_token_pair(self):
|
|
||||||
"""Create mock token pair object."""
|
|
||||||
token_pair = MagicMock()
|
|
||||||
token_pair.access_token = "access_token"
|
|
||||||
token_pair.refresh_token = "refresh_token"
|
|
||||||
token_pair.csrf_token = "csrf_token"
|
|
||||||
token_pair.model_dump.return_value = {
|
|
||||||
"access_token": "access_token",
|
|
||||||
"refresh_token": "refresh_token",
|
|
||||||
"csrf_token": "csrf_token",
|
|
||||||
}
|
|
||||||
return token_pair
|
|
||||||
|
|
||||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||||
@patch("controllers.console.auth.activate.db")
|
@patch("controllers.console.auth.activate.db")
|
||||||
@patch("controllers.console.auth.activate.AccountService.login")
|
|
||||||
def test_successful_account_activation(
|
def test_successful_account_activation(
|
||||||
self,
|
self,
|
||||||
mock_login,
|
|
||||||
mock_db,
|
mock_db,
|
||||||
mock_revoke_token,
|
mock_revoke_token,
|
||||||
mock_get_invitation,
|
mock_get_invitation,
|
||||||
app,
|
app,
|
||||||
mock_invitation,
|
mock_invitation,
|
||||||
mock_account,
|
mock_account,
|
||||||
mock_token_pair,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Test successful account activation.
|
Test successful account activation.
|
||||||
|
|
@ -198,12 +181,10 @@ class TestActivateApi:
|
||||||
Verifies that:
|
Verifies that:
|
||||||
- Account is activated with user preferences
|
- Account is activated with user preferences
|
||||||
- Account status is set to ACTIVE
|
- Account status is set to ACTIVE
|
||||||
- User is logged in after activation
|
|
||||||
- Invitation token is revoked
|
- Invitation token is revoked
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_get_invitation.return_value = mock_invitation
|
mock_get_invitation.return_value = mock_invitation
|
||||||
mock_login.return_value = mock_token_pair
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
|
|
@ -230,7 +211,6 @@ class TestActivateApi:
|
||||||
assert mock_account.initialized_at is not None
|
assert mock_account.initialized_at is not None
|
||||||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||||
mock_db.session.commit.assert_called_once()
|
mock_db.session.commit.assert_called_once()
|
||||||
mock_login.assert_called_once()
|
|
||||||
|
|
||||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||||
def test_activation_with_invalid_token(self, mock_get_invitation, app):
|
def test_activation_with_invalid_token(self, mock_get_invitation, app):
|
||||||
|
|
@ -264,17 +244,14 @@ class TestActivateApi:
|
||||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||||
@patch("controllers.console.auth.activate.db")
|
@patch("controllers.console.auth.activate.db")
|
||||||
@patch("controllers.console.auth.activate.AccountService.login")
|
|
||||||
def test_activation_sets_interface_theme(
|
def test_activation_sets_interface_theme(
|
||||||
self,
|
self,
|
||||||
mock_login,
|
|
||||||
mock_db,
|
mock_db,
|
||||||
mock_revoke_token,
|
mock_revoke_token,
|
||||||
mock_get_invitation,
|
mock_get_invitation,
|
||||||
app,
|
app,
|
||||||
mock_invitation,
|
mock_invitation,
|
||||||
mock_account,
|
mock_account,
|
||||||
mock_token_pair,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Test that activation sets default interface theme.
|
Test that activation sets default interface theme.
|
||||||
|
|
@ -284,7 +261,6 @@ class TestActivateApi:
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_get_invitation.return_value = mock_invitation
|
mock_get_invitation.return_value = mock_invitation
|
||||||
mock_login.return_value = mock_token_pair
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
|
|
@ -317,17 +293,14 @@ class TestActivateApi:
|
||||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||||
@patch("controllers.console.auth.activate.db")
|
@patch("controllers.console.auth.activate.db")
|
||||||
@patch("controllers.console.auth.activate.AccountService.login")
|
|
||||||
def test_activation_with_different_locales(
|
def test_activation_with_different_locales(
|
||||||
self,
|
self,
|
||||||
mock_login,
|
|
||||||
mock_db,
|
mock_db,
|
||||||
mock_revoke_token,
|
mock_revoke_token,
|
||||||
mock_get_invitation,
|
mock_get_invitation,
|
||||||
app,
|
app,
|
||||||
mock_invitation,
|
mock_invitation,
|
||||||
mock_account,
|
mock_account,
|
||||||
mock_token_pair,
|
|
||||||
language,
|
language,
|
||||||
timezone,
|
timezone,
|
||||||
):
|
):
|
||||||
|
|
@ -341,7 +314,6 @@ class TestActivateApi:
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_get_invitation.return_value = mock_invitation
|
mock_get_invitation.return_value = mock_invitation
|
||||||
mock_login.return_value = mock_token_pair
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
|
|
@ -367,27 +339,23 @@ class TestActivateApi:
|
||||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||||
@patch("controllers.console.auth.activate.db")
|
@patch("controllers.console.auth.activate.db")
|
||||||
@patch("controllers.console.auth.activate.AccountService.login")
|
def test_activation_returns_success_response(
|
||||||
def test_activation_returns_token_data(
|
|
||||||
self,
|
self,
|
||||||
mock_login,
|
|
||||||
mock_db,
|
mock_db,
|
||||||
mock_revoke_token,
|
mock_revoke_token,
|
||||||
mock_get_invitation,
|
mock_get_invitation,
|
||||||
app,
|
app,
|
||||||
mock_invitation,
|
mock_invitation,
|
||||||
mock_token_pair,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Test that activation returns authentication tokens.
|
Test that activation returns a success response without authentication tokens.
|
||||||
|
|
||||||
Verifies that:
|
Verifies that:
|
||||||
- Token pair is returned in response
|
- Response contains a success result
|
||||||
- All token types are included (access, refresh, csrf)
|
- No token data is returned
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_get_invitation.return_value = mock_invitation
|
mock_get_invitation.return_value = mock_invitation
|
||||||
mock_login.return_value = mock_token_pair
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
|
|
@ -406,24 +374,18 @@ class TestActivateApi:
|
||||||
response = api.post()
|
response = api.post()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert "data" in response
|
assert response == {"result": "success"}
|
||||||
assert response["data"]["access_token"] == "access_token"
|
|
||||||
assert response["data"]["refresh_token"] == "refresh_token"
|
|
||||||
assert response["data"]["csrf_token"] == "csrf_token"
|
|
||||||
|
|
||||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||||
@patch("controllers.console.auth.activate.db")
|
@patch("controllers.console.auth.activate.db")
|
||||||
@patch("controllers.console.auth.activate.AccountService.login")
|
|
||||||
def test_activation_without_workspace_id(
|
def test_activation_without_workspace_id(
|
||||||
self,
|
self,
|
||||||
mock_login,
|
|
||||||
mock_db,
|
mock_db,
|
||||||
mock_revoke_token,
|
mock_revoke_token,
|
||||||
mock_get_invitation,
|
mock_get_invitation,
|
||||||
app,
|
app,
|
||||||
mock_invitation,
|
mock_invitation,
|
||||||
mock_token_pair,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Test account activation without workspace_id.
|
Test account activation without workspace_id.
|
||||||
|
|
@ -434,7 +396,6 @@ class TestActivateApi:
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_get_invitation.return_value = mock_invitation
|
mock_get_invitation.return_value = mock_invitation
|
||||||
mock_login.return_value = mock_token_pair
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with app.test_request_context(
|
with app.test_request_context(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,236 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
from flask.views import MethodView as FlaskMethodView
|
||||||
|
|
||||||
|
_NEEDS_METHOD_VIEW_CLEANUP = False
|
||||||
|
if not hasattr(builtins, "MethodView"):
|
||||||
|
builtins.MethodView = FlaskMethodView
|
||||||
|
_NEEDS_METHOD_VIEW_CLEANUP = True
|
||||||
|
|
||||||
|
from constants import HIDDEN_VALUE
|
||||||
|
from controllers.console.extension import (
|
||||||
|
APIBasedExtensionAPI,
|
||||||
|
APIBasedExtensionDetailAPI,
|
||||||
|
CodeBasedExtensionAPI,
|
||||||
|
)
|
||||||
|
|
||||||
|
if _NEEDS_METHOD_VIEW_CLEANUP:
|
||||||
|
delattr(builtins, "MethodView")
|
||||||
|
from models.account import AccountStatus
|
||||||
|
from models.api_based_extension import APIBasedExtension
|
||||||
|
|
||||||
|
|
||||||
|
def _make_extension(
|
||||||
|
*,
|
||||||
|
name: str = "Sample Extension",
|
||||||
|
api_endpoint: str = "https://example.com/api",
|
||||||
|
api_key: str = "super-secret-key",
|
||||||
|
) -> APIBasedExtension:
|
||||||
|
extension = APIBasedExtension(
|
||||||
|
tenant_id="tenant-123",
|
||||||
|
name=name,
|
||||||
|
api_endpoint=api_endpoint,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
extension.id = f"{uuid.uuid4()}"
|
||||||
|
extension.created_at = datetime.now(tz=UTC)
|
||||||
|
return extension
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _mock_console_guards(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||||
|
"""Bypass console decorators so handlers can run in isolation."""
|
||||||
|
|
||||||
|
import controllers.console.extension as extension_module
|
||||||
|
from controllers.console import wraps as wraps_module
|
||||||
|
|
||||||
|
account = MagicMock()
|
||||||
|
account.status = AccountStatus.ACTIVE
|
||||||
|
account.current_tenant_id = "tenant-123"
|
||||||
|
account.id = "account-123"
|
||||||
|
account.is_authenticated = True
|
||||||
|
|
||||||
|
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
|
||||||
|
monkeypatch.setattr("libs.login.dify_config.LOGIN_DISABLED", True)
|
||||||
|
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||||
|
monkeypatch.setattr(extension_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
|
||||||
|
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
|
||||||
|
|
||||||
|
# The login_required decorator consults the shared LocalProxy in libs.login.
|
||||||
|
monkeypatch.setattr("libs.login.current_user", account)
|
||||||
|
monkeypatch.setattr("libs.login.check_csrf_token", lambda *_, **__: None)
|
||||||
|
|
||||||
|
return account
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _restx_mask_defaults(app: Flask):
|
||||||
|
app.config.setdefault("RESTX_MASK_HEADER", "X-Fields")
|
||||||
|
app.config.setdefault("RESTX_MASK_SWAGGER", False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_code_based_extension_get_returns_service_data(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
service_result = {"entrypoint": "main:agent"}
|
||||||
|
service_mock = MagicMock(return_value=service_result)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.CodeBasedExtensionService.get_code_based_extension",
|
||||||
|
service_mock,
|
||||||
|
)
|
||||||
|
|
||||||
|
with app.test_request_context(
|
||||||
|
"/console/api/code-based-extension",
|
||||||
|
method="GET",
|
||||||
|
query_string={"module": "workflow.tools"},
|
||||||
|
):
|
||||||
|
response = CodeBasedExtensionAPI().get()
|
||||||
|
|
||||||
|
assert response == {"module": "workflow.tools", "data": service_result}
|
||||||
|
service_mock.assert_called_once_with("workflow.tools")
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_get_returns_tenant_extensions(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
extension = _make_extension(name="Weather API", api_key="abcdefghi123")
|
||||||
|
service_mock = MagicMock(return_value=[extension])
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.APIBasedExtensionService.get_all_by_tenant_id",
|
||||||
|
service_mock,
|
||||||
|
)
|
||||||
|
|
||||||
|
with app.test_request_context("/console/api/api-based-extension", method="GET"):
|
||||||
|
response = APIBasedExtensionAPI().get()
|
||||||
|
|
||||||
|
assert response[0]["id"] == extension.id
|
||||||
|
assert response[0]["name"] == "Weather API"
|
||||||
|
assert response[0]["api_endpoint"] == extension.api_endpoint
|
||||||
|
assert response[0]["api_key"].startswith(extension.api_key[:3])
|
||||||
|
service_mock.assert_called_once_with("tenant-123")
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
saved_extension = _make_extension(name="Docs API", api_key="saved-secret")
|
||||||
|
save_mock = MagicMock(return_value=saved_extension)
|
||||||
|
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"name": "Docs API",
|
||||||
|
"api_endpoint": "https://docs.example.com/hook",
|
||||||
|
"api_key": "plain-secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
with app.test_request_context("/console/api/api-based-extension", method="POST", json=payload):
|
||||||
|
response = APIBasedExtensionAPI().post()
|
||||||
|
|
||||||
|
args, _ = save_mock.call_args
|
||||||
|
created_extension: APIBasedExtension = args[0]
|
||||||
|
assert created_extension.tenant_id == "tenant-123"
|
||||||
|
assert created_extension.name == payload["name"]
|
||||||
|
assert created_extension.api_endpoint == payload["api_endpoint"]
|
||||||
|
assert created_extension.api_key == payload["api_key"]
|
||||||
|
assert response["name"] == saved_extension.name
|
||||||
|
save_mock.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_detail_get_fetches_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
extension = _make_extension(name="Docs API", api_key="abcdefg12345")
|
||||||
|
service_mock = MagicMock(return_value=extension)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||||
|
service_mock,
|
||||||
|
)
|
||||||
|
|
||||||
|
extension_id = uuid.uuid4()
|
||||||
|
with app.test_request_context(f"/console/api/api-based-extension/{extension_id}", method="GET"):
|
||||||
|
response = APIBasedExtensionDetailAPI().get(extension_id)
|
||||||
|
|
||||||
|
assert response["id"] == extension.id
|
||||||
|
assert response["name"] == extension.name
|
||||||
|
service_mock.assert_called_once_with("tenant-123", str(extension_id))
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
existing_extension = _make_extension(name="Docs API", api_key="keep-me")
|
||||||
|
get_mock = MagicMock(return_value=existing_extension)
|
||||||
|
save_mock = MagicMock(return_value=existing_extension)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||||
|
get_mock,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"name": "Docs API Updated",
|
||||||
|
"api_endpoint": "https://docs.example.com/v2",
|
||||||
|
"api_key": HIDDEN_VALUE,
|
||||||
|
}
|
||||||
|
|
||||||
|
extension_id = uuid.uuid4()
|
||||||
|
with app.test_request_context(
|
||||||
|
f"/console/api/api-based-extension/{extension_id}",
|
||||||
|
method="POST",
|
||||||
|
json=payload,
|
||||||
|
):
|
||||||
|
response = APIBasedExtensionDetailAPI().post(extension_id)
|
||||||
|
|
||||||
|
assert existing_extension.name == payload["name"]
|
||||||
|
assert existing_extension.api_endpoint == payload["api_endpoint"]
|
||||||
|
assert existing_extension.api_key == "keep-me"
|
||||||
|
save_mock.assert_called_once_with(existing_extension)
|
||||||
|
assert response["name"] == payload["name"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_detail_post_updates_api_key_when_provided(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
existing_extension = _make_extension(name="Docs API", api_key="old-secret")
|
||||||
|
get_mock = MagicMock(return_value=existing_extension)
|
||||||
|
save_mock = MagicMock(return_value=existing_extension)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||||
|
get_mock,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"name": "Docs API Updated",
|
||||||
|
"api_endpoint": "https://docs.example.com/v2",
|
||||||
|
"api_key": "new-secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
extension_id = uuid.uuid4()
|
||||||
|
with app.test_request_context(
|
||||||
|
f"/console/api/api-based-extension/{extension_id}",
|
||||||
|
method="POST",
|
||||||
|
json=payload,
|
||||||
|
):
|
||||||
|
response = APIBasedExtensionDetailAPI().post(extension_id)
|
||||||
|
|
||||||
|
assert existing_extension.api_key == "new-secret"
|
||||||
|
save_mock.assert_called_once_with(existing_extension)
|
||||||
|
assert response["name"] == payload["name"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
existing_extension = _make_extension()
|
||||||
|
get_mock = MagicMock(return_value=existing_extension)
|
||||||
|
delete_mock = MagicMock()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||||
|
get_mock,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.delete", delete_mock)
|
||||||
|
|
||||||
|
extension_id = uuid.uuid4()
|
||||||
|
with app.test_request_context(
|
||||||
|
f"/console/api/api-based-extension/{extension_id}",
|
||||||
|
method="DELETE",
|
||||||
|
):
|
||||||
|
response, status = APIBasedExtensionDetailAPI().delete(extension_id)
|
||||||
|
|
||||||
|
delete_mock.assert_called_once_with(existing_extension)
|
||||||
|
assert response == {"result": "success"}
|
||||||
|
assert status == 204
|
||||||
|
|
@ -41,6 +41,7 @@ class TestFilePreviewApi:
|
||||||
upload_file = Mock(spec=UploadFile)
|
upload_file = Mock(spec=UploadFile)
|
||||||
upload_file.id = str(uuid.uuid4())
|
upload_file.id = str(uuid.uuid4())
|
||||||
upload_file.name = "test_file.jpg"
|
upload_file.name = "test_file.jpg"
|
||||||
|
upload_file.extension = "jpg"
|
||||||
upload_file.mime_type = "image/jpeg"
|
upload_file.mime_type = "image/jpeg"
|
||||||
upload_file.size = 1024
|
upload_file.size = 1024
|
||||||
upload_file.key = "storage/key/test_file.jpg"
|
upload_file.key = "storage/key/test_file.jpg"
|
||||||
|
|
@ -210,6 +211,19 @@ class TestFilePreviewApi:
|
||||||
assert mock_upload_file.name in response.headers["Content-Disposition"]
|
assert mock_upload_file.name in response.headers["Content-Disposition"]
|
||||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||||
|
|
||||||
|
def test_build_file_response_html_forces_attachment(self, file_preview_api, mock_upload_file):
|
||||||
|
"""Test HTML files are forced to download"""
|
||||||
|
mock_generator = Mock()
|
||||||
|
mock_upload_file.mime_type = "text/html"
|
||||||
|
mock_upload_file.name = "unsafe.html"
|
||||||
|
mock_upload_file.extension = "html"
|
||||||
|
|
||||||
|
response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False)
|
||||||
|
|
||||||
|
assert "attachment" in response.headers["Content-Disposition"]
|
||||||
|
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||||
|
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||||
|
|
||||||
def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file):
|
def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file):
|
||||||
"""Test file response building for audio/video files"""
|
"""Test file response building for audio/video files"""
|
||||||
mock_generator = Mock()
|
mock_generator = Mock()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,195 @@
|
||||||
|
"""Unit tests for controllers.web.forgot_password endpoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import builtins
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
from flask.views import MethodView
|
||||||
|
|
||||||
|
# Ensure flask_restx.api finds MethodView during import.
|
||||||
|
if not hasattr(builtins, "MethodView"):
|
||||||
|
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_controller_module():
|
||||||
|
"""Import controllers.web.forgot_password using a stub package."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
|
parent_module_name = "controllers.web"
|
||||||
|
module_name = f"{parent_module_name}.forgot_password"
|
||||||
|
|
||||||
|
if parent_module_name not in sys.modules:
|
||||||
|
from flask_restx import Namespace
|
||||||
|
|
||||||
|
stub = ModuleType(parent_module_name)
|
||||||
|
stub.__file__ = "controllers/web/__init__.py"
|
||||||
|
stub.__path__ = ["controllers/web"]
|
||||||
|
stub.__package__ = "controllers"
|
||||||
|
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
|
||||||
|
stub.web_ns = Namespace("web", description="Web API", path="/")
|
||||||
|
sys.modules[parent_module_name] = stub
|
||||||
|
|
||||||
|
return importlib.import_module(module_name)
|
||||||
|
|
||||||
|
|
||||||
|
forgot_password_module = _load_controller_module()
|
||||||
|
ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi
|
||||||
|
ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi
|
||||||
|
ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app() -> Flask:
|
||||||
|
"""Configure a minimal Flask app for request contexts."""
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config["TESTING"] = True
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _enable_web_endpoint_guards():
|
||||||
|
"""Stub enterprise and feature toggles used by route decorators."""
|
||||||
|
|
||||||
|
features = SimpleNamespace(enable_email_password_login=True)
|
||||||
|
with (
|
||||||
|
patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True),
|
||||||
|
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
|
||||||
|
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features),
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _mock_controller_db():
|
||||||
|
"""Replace controller-level db reference with a simple stub."""
|
||||||
|
|
||||||
|
fake_db = SimpleNamespace(engine=MagicMock(name="engine"))
|
||||||
|
fake_wraps_db = SimpleNamespace(
|
||||||
|
session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True))))
|
||||||
|
)
|
||||||
|
with (
|
||||||
|
patch("controllers.web.forgot_password.db", fake_db),
|
||||||
|
patch("controllers.console.wraps.db", fake_wraps_db),
|
||||||
|
):
|
||||||
|
yield fake_db
|
||||||
|
|
||||||
|
|
||||||
|
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token")
|
||||||
|
@patch("controllers.web.forgot_password.Session")
|
||||||
|
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||||
|
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10")
|
||||||
|
def test_send_reset_email_success(
|
||||||
|
mock_extract_ip: MagicMock,
|
||||||
|
mock_is_ip_limit: MagicMock,
|
||||||
|
mock_session: MagicMock,
|
||||||
|
mock_send_email: MagicMock,
|
||||||
|
app: Flask,
|
||||||
|
):
|
||||||
|
"""POST /forgot-password returns token when email exists and limits allow."""
|
||||||
|
|
||||||
|
mock_account = MagicMock()
|
||||||
|
session_ctx = MagicMock()
|
||||||
|
mock_session.return_value.__enter__.return_value = session_ctx
|
||||||
|
session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||||
|
|
||||||
|
with app.test_request_context(
|
||||||
|
"/forgot-password",
|
||||||
|
method="POST",
|
||||||
|
json={"email": "user@example.com"},
|
||||||
|
):
|
||||||
|
response = ForgotPasswordSendEmailApi().post()
|
||||||
|
|
||||||
|
assert response == {"result": "success", "data": "reset-token"}
|
||||||
|
mock_extract_ip.assert_called_once()
|
||||||
|
mock_is_ip_limit.assert_called_once_with("203.0.113.10")
|
||||||
|
mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||||
|
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token"))
|
||||||
|
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||||
|
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||||
|
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False)
|
||||||
|
def test_check_token_success(
|
||||||
|
mock_is_rate_limited: MagicMock,
|
||||||
|
mock_get_data: MagicMock,
|
||||||
|
mock_revoke: MagicMock,
|
||||||
|
mock_generate: MagicMock,
|
||||||
|
mock_reset_limit: MagicMock,
|
||||||
|
app: Flask,
|
||||||
|
):
|
||||||
|
"""POST /forgot-password/validity validates the code and refreshes token."""
|
||||||
|
|
||||||
|
mock_get_data.return_value = {"email": "user@example.com", "code": "123456"}
|
||||||
|
|
||||||
|
with app.test_request_context(
|
||||||
|
"/forgot-password/validity",
|
||||||
|
method="POST",
|
||||||
|
json={"email": "user@example.com", "code": "123456", "token": "old-token"},
|
||||||
|
):
|
||||||
|
response = ForgotPasswordCheckApi().post()
|
||||||
|
|
||||||
|
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||||
|
mock_is_rate_limited.assert_called_once_with("user@example.com")
|
||||||
|
mock_get_data.assert_called_once_with("old-token")
|
||||||
|
mock_revoke.assert_called_once_with("old-token")
|
||||||
|
mock_generate.assert_called_once_with(
|
||||||
|
"user@example.com",
|
||||||
|
code="123456",
|
||||||
|
additional_data={"phase": "reset"},
|
||||||
|
)
|
||||||
|
mock_reset_limit.assert_called_once_with("user@example.com")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||||
|
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||||
|
@patch("controllers.web.forgot_password.Session")
|
||||||
|
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||||
|
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||||
|
def test_reset_password_success(
|
||||||
|
mock_get_data: MagicMock,
|
||||||
|
mock_revoke_token: MagicMock,
|
||||||
|
mock_session: MagicMock,
|
||||||
|
mock_token_bytes: MagicMock,
|
||||||
|
mock_hash_password: MagicMock,
|
||||||
|
app: Flask,
|
||||||
|
):
|
||||||
|
"""POST /forgot-password/resets updates the stored password when token is valid."""
|
||||||
|
|
||||||
|
mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"}
|
||||||
|
account = MagicMock()
|
||||||
|
session_ctx = MagicMock()
|
||||||
|
mock_session.return_value.__enter__.return_value = session_ctx
|
||||||
|
session_ctx.execute.return_value.scalar_one_or_none.return_value = account
|
||||||
|
|
||||||
|
with app.test_request_context(
|
||||||
|
"/forgot-password/resets",
|
||||||
|
method="POST",
|
||||||
|
json={
|
||||||
|
"token": "reset-token",
|
||||||
|
"new_password": "StrongPass123!",
|
||||||
|
"password_confirm": "StrongPass123!",
|
||||||
|
},
|
||||||
|
):
|
||||||
|
response = ForgotPasswordResetApi().post()
|
||||||
|
|
||||||
|
assert response == {"result": "success"}
|
||||||
|
mock_get_data.assert_called_once_with("reset-token")
|
||||||
|
mock_revoke_token.assert_called_once_with("reset-token")
|
||||||
|
mock_token_bytes.assert_called_once_with(16)
|
||||||
|
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
|
||||||
|
expected_password = base64.b64encode(b"hashed-value").decode()
|
||||||
|
assert account.password == expected_password
|
||||||
|
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
|
||||||
|
assert account.password_salt == expected_salt
|
||||||
|
session_ctx.commit.assert_called_once()
|
||||||
|
|
@ -287,7 +287,7 @@ def test_validate_inputs_optional_file_with_empty_string():
|
||||||
|
|
||||||
|
|
||||||
def test_validate_inputs_optional_file_list_with_empty_list():
|
def test_validate_inputs_optional_file_list_with_empty_list():
|
||||||
"""Test that optional FILE_LIST variable with empty list returns None"""
|
"""Test that optional FILE_LIST variable with empty list returns empty list (not None)"""
|
||||||
base_app_generator = BaseAppGenerator()
|
base_app_generator = BaseAppGenerator()
|
||||||
|
|
||||||
var_file_list = VariableEntity(
|
var_file_list = VariableEntity(
|
||||||
|
|
@ -302,6 +302,28 @@ def test_validate_inputs_optional_file_list_with_empty_list():
|
||||||
value=[],
|
value=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Empty list should be preserved, not converted to None
|
||||||
|
# This allows downstream components like document_extractor to handle empty lists properly
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_inputs_optional_file_list_with_empty_string():
|
||||||
|
"""Test that optional FILE_LIST variable with empty string returns None"""
|
||||||
|
base_app_generator = BaseAppGenerator()
|
||||||
|
|
||||||
|
var_file_list = VariableEntity(
|
||||||
|
variable="test_file_list",
|
||||||
|
label="test_file_list",
|
||||||
|
type=VariableEntityType.FILE_LIST,
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = base_app_generator._validate_inputs(
|
||||||
|
variable_entity=var_file_list,
|
||||||
|
value="",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Empty string should be treated as unset
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,7 @@ class TestNotionExtractorAuthentication:
|
||||||
def test_init_with_integration_token_fallback(self, mock_get_token, mock_config, mock_document_model):
|
def test_init_with_integration_token_fallback(self, mock_get_token, mock_config, mock_document_model):
|
||||||
"""Test NotionExtractor falls back to integration token when credential not found."""
|
"""Test NotionExtractor falls back to integration token when credential not found."""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_get_token.return_value = None
|
mock_get_token.side_effect = Exception("No credential id found")
|
||||||
mock_config.NOTION_INTEGRATION_TOKEN = "integration-token-fallback"
|
mock_config.NOTION_INTEGRATION_TOKEN = "integration-token-fallback"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
|
@ -105,7 +105,7 @@ class TestNotionExtractorAuthentication:
|
||||||
notion_obj_id="page-456",
|
notion_obj_id="page-456",
|
||||||
notion_page_type="page",
|
notion_page_type="page",
|
||||||
tenant_id="tenant-789",
|
tenant_id="tenant-789",
|
||||||
credential_id="cred-123",
|
credential_id=None,
|
||||||
document_model=mock_document_model,
|
document_model=mock_document_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -117,7 +117,7 @@ class TestNotionExtractorAuthentication:
|
||||||
def test_init_missing_credentials_raises_error(self, mock_get_token, mock_config, mock_document_model):
|
def test_init_missing_credentials_raises_error(self, mock_get_token, mock_config, mock_document_model):
|
||||||
"""Test NotionExtractor raises error when no credentials available."""
|
"""Test NotionExtractor raises error when no credentials available."""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_get_token.return_value = None
|
mock_get_token.side_effect = Exception("No credential id found")
|
||||||
mock_config.NOTION_INTEGRATION_TOKEN = None
|
mock_config.NOTION_INTEGRATION_TOKEN = None
|
||||||
|
|
||||||
# Act & Assert
|
# Act & Assert
|
||||||
|
|
@ -127,7 +127,7 @@ class TestNotionExtractorAuthentication:
|
||||||
notion_obj_id="page-456",
|
notion_obj_id="page-456",
|
||||||
notion_page_type="page",
|
notion_page_type="page",
|
||||||
tenant_id="tenant-789",
|
tenant_id="tenant-789",
|
||||||
credential_id="cred-123",
|
credential_id=None,
|
||||||
document_model=mock_document_model,
|
document_model=mock_document_model,
|
||||||
)
|
)
|
||||||
assert "Must specify `integration_token`" in str(exc_info.value)
|
assert "Must specify `integration_token`" in str(exc_info.value)
|
||||||
|
|
|
||||||
|
|
@ -1,52 +1,109 @@
|
||||||
import secrets
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
|
from core.helper.ssrf_proxy import (
|
||||||
|
SSRF_DEFAULT_MAX_RETRIES,
|
||||||
|
_get_user_provided_host_header,
|
||||||
|
make_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@patch("httpx.Client.request")
|
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||||
def test_successful_request(mock_request):
|
def test_successful_request(mock_get_client):
|
||||||
|
mock_client = MagicMock()
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_request.return_value = mock_response
|
mock_client.send.return_value = mock_response
|
||||||
|
mock_client.request.return_value = mock_response
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
response = make_request("GET", "http://example.com")
|
response = make_request("GET", "http://example.com")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
@patch("httpx.Client.request")
|
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||||
def test_retry_exceed_max_retries(mock_request):
|
def test_retry_exceed_max_retries(mock_get_client):
|
||||||
|
mock_client = MagicMock()
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 500
|
mock_response.status_code = 500
|
||||||
|
mock_client.send.return_value = mock_response
|
||||||
side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES
|
mock_client.request.return_value = mock_response
|
||||||
mock_request.side_effect = side_effects
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
with pytest.raises(Exception) as e:
|
with pytest.raises(Exception) as e:
|
||||||
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
|
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
|
||||||
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
|
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
|
||||||
|
|
||||||
|
|
||||||
@patch("httpx.Client.request")
|
class TestGetUserProvidedHostHeader:
|
||||||
def test_retry_logic_success(mock_request):
|
"""Tests for _get_user_provided_host_header function."""
|
||||||
side_effects = []
|
|
||||||
|
|
||||||
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
|
def test_returns_none_when_headers_is_none(self):
|
||||||
status_code = secrets.choice(STATUS_FORCELIST)
|
assert _get_user_provided_host_header(None) is None
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.status_code = status_code
|
|
||||||
side_effects.append(mock_response)
|
|
||||||
|
|
||||||
mock_response_200 = MagicMock()
|
def test_returns_none_when_headers_is_empty(self):
|
||||||
mock_response_200.status_code = 200
|
assert _get_user_provided_host_header({}) is None
|
||||||
side_effects.append(mock_response_200)
|
|
||||||
|
|
||||||
mock_request.side_effect = side_effects
|
def test_returns_none_when_host_header_not_present(self):
|
||||||
|
headers = {"Content-Type": "application/json", "Authorization": "Bearer token"}
|
||||||
|
assert _get_user_provided_host_header(headers) is None
|
||||||
|
|
||||||
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
|
def test_returns_host_header_lowercase(self):
|
||||||
|
headers = {"host": "example.com"}
|
||||||
|
assert _get_user_provided_host_header(headers) == "example.com"
|
||||||
|
|
||||||
|
def test_returns_host_header_uppercase(self):
|
||||||
|
headers = {"HOST": "example.com"}
|
||||||
|
assert _get_user_provided_host_header(headers) == "example.com"
|
||||||
|
|
||||||
|
def test_returns_host_header_mixed_case(self):
|
||||||
|
headers = {"HoSt": "example.com"}
|
||||||
|
assert _get_user_provided_host_header(headers) == "example.com"
|
||||||
|
|
||||||
|
def test_returns_host_header_from_multiple_headers(self):
|
||||||
|
headers = {"Content-Type": "application/json", "Host": "api.example.com", "Authorization": "Bearer token"}
|
||||||
|
assert _get_user_provided_host_header(headers) == "api.example.com"
|
||||||
|
|
||||||
|
def test_returns_first_host_header_when_duplicates(self):
|
||||||
|
headers = {"host": "first.com", "Host": "second.com"}
|
||||||
|
# Should return the first one encountered (iteration order is preserved in dict)
|
||||||
|
result = _get_user_provided_host_header(headers)
|
||||||
|
assert result in ("first.com", "second.com")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||||
|
def test_host_header_preservation_without_user_header(mock_get_client):
|
||||||
|
"""Test that when no Host header is provided, the default behavior is maintained."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {}
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_client.send.return_value = mock_response
|
||||||
|
mock_client.request.return_value = mock_response
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
response = make_request("GET", "http://example.com")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
# Host should not be set if not provided by user
|
||||||
|
assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||||
|
def test_host_header_preservation_with_user_header(mock_get_client):
|
||||||
|
"""Test that user-provided Host header is preserved in the request."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {}
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_client.send.return_value = mock_response
|
||||||
|
mock_client.request.return_value = mock_response
|
||||||
|
mock_get_client.return_value = mock_client
|
||||||
|
|
||||||
|
custom_host = "custom.example.com:8080"
|
||||||
|
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
|
|
||||||
assert mock_request.call_args_list[0][1].get("method") == "GET"
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
|
from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path
|
||||||
|
|
||||||
|
|
||||||
class TestValidateUrl:
|
class TestValidateUrl:
|
||||||
|
|
@ -136,3 +139,51 @@ class TestValidateProjectName:
|
||||||
"""Test custom default name"""
|
"""Test custom default name"""
|
||||||
result = validate_project_name("", "Custom Default")
|
result = validate_project_name("", "Custom Default")
|
||||||
assert result == "Custom Default"
|
assert result == "Custom Default"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateDottedOrder:
|
||||||
|
"""Test cases for generate_dotted_order function"""
|
||||||
|
|
||||||
|
def test_dotted_order_has_6_digit_microseconds(self):
|
||||||
|
"""Test that timestamp includes full 6-digit microseconds for LangSmith API compatibility.
|
||||||
|
|
||||||
|
LangSmith API expects timestamps in format: YYYYMMDDTHHMMSSffffffZ (6-digit microseconds).
|
||||||
|
Previously, the code truncated to 3 digits which caused API errors:
|
||||||
|
'cannot parse .111 as .000000'
|
||||||
|
"""
|
||||||
|
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
|
||||||
|
run_id = "test-run-id"
|
||||||
|
result = generate_dotted_order(run_id, start_time)
|
||||||
|
|
||||||
|
# Extract timestamp portion (before the run_id)
|
||||||
|
timestamp_match = re.match(r"^(\d{8}T\d{6})(\d+)Z", result)
|
||||||
|
assert timestamp_match is not None, "Timestamp format should match YYYYMMDDTHHMMSSffffffZ"
|
||||||
|
|
||||||
|
microseconds = timestamp_match.group(2)
|
||||||
|
assert len(microseconds) == 6, f"Microseconds should be 6 digits, got {len(microseconds)}: {microseconds}"
|
||||||
|
|
||||||
|
def test_dotted_order_format_matches_langsmith_expected(self):
|
||||||
|
"""Test that dotted_order format matches LangSmith API expected format."""
|
||||||
|
start_time = datetime(2025, 1, 15, 10, 30, 45, 123456)
|
||||||
|
run_id = "abc123"
|
||||||
|
result = generate_dotted_order(run_id, start_time)
|
||||||
|
|
||||||
|
# LangSmith expects: YYYYMMDDTHHMMSSffffffZ followed by run_id
|
||||||
|
assert result == "20250115T103045123456Zabc123"
|
||||||
|
|
||||||
|
def test_dotted_order_with_parent(self):
|
||||||
|
"""Test dotted_order generation with parent order uses dot separator."""
|
||||||
|
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
|
||||||
|
run_id = "child-run-id"
|
||||||
|
parent_order = "20251223T041955000000Zparent-run-id"
|
||||||
|
result = generate_dotted_order(run_id, start_time, parent_order)
|
||||||
|
|
||||||
|
assert result == "20251223T041955000000Zparent-run-id.20251223T041955111000Zchild-run-id"
|
||||||
|
|
||||||
|
def test_dotted_order_without_parent_has_no_dot(self):
|
||||||
|
"""Test dotted_order generation without parent has no dot separator."""
|
||||||
|
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
|
||||||
|
run_id = "test-run-id"
|
||||||
|
result = generate_dotted_order(run_id, start_time, None)
|
||||||
|
|
||||||
|
assert "." not in result
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
import os
|
import os
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||||
|
|
@ -25,3 +27,35 @@ def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
|
||||||
|
|
||||||
assert job_id is not None
|
assert job_id is not None
|
||||||
assert isinstance(job_id, str)
|
assert isinstance(job_id, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_url_normalizes_slashes_for_crawl(mocker: MockerFixture):
|
||||||
|
api_key = "fc-"
|
||||||
|
base_urls = ["https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"]
|
||||||
|
for base in base_urls:
|
||||||
|
app = FirecrawlApp(api_key=api_key, base_url=base)
|
||||||
|
mock_post = mocker.patch("httpx.post")
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.json.return_value = {"id": "job123"}
|
||||||
|
mock_post.return_value = mock_resp
|
||||||
|
app.crawl_url("https://example.com", params=None)
|
||||||
|
called_url = mock_post.call_args[0][0]
|
||||||
|
assert called_url == "https://custom.firecrawl.dev/v2/crawl"
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_handler_handles_non_json_error_bodies(mocker: MockerFixture):
|
||||||
|
api_key = "fc-"
|
||||||
|
app = FirecrawlApp(api_key=api_key, base_url="https://custom.firecrawl.dev/")
|
||||||
|
mock_post = mocker.patch("httpx.post")
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 404
|
||||||
|
mock_resp.text = "Not Found"
|
||||||
|
mock_resp.json.side_effect = Exception("Not JSON")
|
||||||
|
mock_post.return_value = mock_resp
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as excinfo:
|
||||||
|
app.scrape_url("https://example.com")
|
||||||
|
|
||||||
|
# Should not raise a JSONDecodeError; current behavior reports status code only
|
||||||
|
assert str(excinfo.value) == "Failed to scrape URL. Status code: 404"
|
||||||
|
|
|
||||||
|
|
@ -132,3 +132,36 @@ def test_extract_images_from_docx(monkeypatch):
|
||||||
# DB interactions should be recorded
|
# DB interactions should be recorded
|
||||||
assert len(db_stub.session.added) == 2
|
assert len(db_stub.session.added) == 2
|
||||||
assert db_stub.session.committed is True
|
assert db_stub.session.committed is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_images_from_docx_uses_internal_files_url():
|
||||||
|
"""Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access."""
|
||||||
|
# Test the URL generation logic directly
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
# Mock the configuration values
|
||||||
|
original_files_url = getattr(dify_config, "FILES_URL", None)
|
||||||
|
original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Set both URLs - INTERNAL should take precedence
|
||||||
|
dify_config.FILES_URL = "http://external.example.com"
|
||||||
|
dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001"
|
||||||
|
|
||||||
|
# Test the URL generation logic (same as in word_extractor.py)
|
||||||
|
upload_file_id = "test_file_id"
|
||||||
|
|
||||||
|
# This is the pattern we fixed in the word extractor
|
||||||
|
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||||
|
generated_url = f"{base_url}/files/{upload_file_id}/file-preview"
|
||||||
|
|
||||||
|
# Verify that INTERNAL_FILES_URL is used instead of FILES_URL
|
||||||
|
assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: {generated_url}"
|
||||||
|
assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore original values
|
||||||
|
if original_files_url is not None:
|
||||||
|
dify_config.FILES_URL = original_files_url
|
||||||
|
if original_internal_files_url is not None:
|
||||||
|
dify_config.INTERNAL_FILES_URL = original_internal_files_url
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
from core.workflow.nodes.http_request import (
|
from core.workflow.nodes.http_request import (
|
||||||
BodyData,
|
BodyData,
|
||||||
HttpRequestNodeAuthorization,
|
HttpRequestNodeAuthorization,
|
||||||
|
|
@ -5,6 +7,7 @@ from core.workflow.nodes.http_request import (
|
||||||
HttpRequestNodeData,
|
HttpRequestNodeData,
|
||||||
)
|
)
|
||||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||||
|
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
|
||||||
from core.workflow.nodes.http_request.executor import Executor
|
from core.workflow.nodes.http_request.executor import Executor
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
from core.workflow.system_variable import SystemVariable
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
|
@ -348,3 +351,127 @@ def test_init_params():
|
||||||
executor = create_executor("key1:value1\n\nkey2:value2\n\n")
|
executor = create_executor("key1:value1\n\nkey2:value2\n\n")
|
||||||
executor._init_params()
|
executor._init_params()
|
||||||
assert executor.params == [("key1", "value1"), ("key2", "value2")]
|
assert executor.params == [("key1", "value1"), ("key2", "value2")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_api_key_raises_error_bearer():
|
||||||
|
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
|
||||||
|
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||||
|
node_data = HttpRequestNodeData(
|
||||||
|
title="test",
|
||||||
|
method="get",
|
||||||
|
url="http://example.com",
|
||||||
|
headers="",
|
||||||
|
params="",
|
||||||
|
authorization=HttpRequestNodeAuthorization(
|
||||||
|
type="api-key",
|
||||||
|
config={"type": "bearer", "api_key": ""},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||||
|
|
||||||
|
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||||
|
Executor(
|
||||||
|
node_data=node_data,
|
||||||
|
timeout=timeout,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_api_key_raises_error_basic():
|
||||||
|
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
|
||||||
|
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||||
|
node_data = HttpRequestNodeData(
|
||||||
|
title="test",
|
||||||
|
method="get",
|
||||||
|
url="http://example.com",
|
||||||
|
headers="",
|
||||||
|
params="",
|
||||||
|
authorization=HttpRequestNodeAuthorization(
|
||||||
|
type="api-key",
|
||||||
|
config={"type": "basic", "api_key": ""},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||||
|
|
||||||
|
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||||
|
Executor(
|
||||||
|
node_data=node_data,
|
||||||
|
timeout=timeout,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_api_key_raises_error_custom():
|
||||||
|
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
|
||||||
|
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||||
|
node_data = HttpRequestNodeData(
|
||||||
|
title="test",
|
||||||
|
method="get",
|
||||||
|
url="http://example.com",
|
||||||
|
headers="",
|
||||||
|
params="",
|
||||||
|
authorization=HttpRequestNodeAuthorization(
|
||||||
|
type="api-key",
|
||||||
|
config={"type": "custom", "api_key": "", "header": "X-Custom-Auth"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||||
|
|
||||||
|
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||||
|
Executor(
|
||||||
|
node_data=node_data,
|
||||||
|
timeout=timeout,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_whitespace_only_api_key_raises_error():
|
||||||
|
"""Test that whitespace-only API key raises AuthorizationConfigError."""
|
||||||
|
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||||
|
node_data = HttpRequestNodeData(
|
||||||
|
title="test",
|
||||||
|
method="get",
|
||||||
|
url="http://example.com",
|
||||||
|
headers="",
|
||||||
|
params="",
|
||||||
|
authorization=HttpRequestNodeAuthorization(
|
||||||
|
type="api-key",
|
||||||
|
config={"type": "bearer", "api_key": " "},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||||
|
|
||||||
|
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||||
|
Executor(
|
||||||
|
node_data=node_data,
|
||||||
|
timeout=timeout,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_api_key_works():
|
||||||
|
"""Test that valid API key works correctly for bearer auth."""
|
||||||
|
variable_pool = VariablePool(system_variables=SystemVariable.empty())
|
||||||
|
node_data = HttpRequestNodeData(
|
||||||
|
title="test",
|
||||||
|
method="get",
|
||||||
|
url="http://example.com",
|
||||||
|
headers="",
|
||||||
|
params="",
|
||||||
|
authorization=HttpRequestNodeAuthorization(
|
||||||
|
type="api-key",
|
||||||
|
config={"type": "bearer", "api_key": "valid-api-key-123"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
|
||||||
|
|
||||||
|
executor = Executor(
|
||||||
|
node_data=node_data,
|
||||||
|
timeout=timeout,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise an error
|
||||||
|
headers = executor._assembling_headers()
|
||||||
|
assert "Authorization" in headers
|
||||||
|
assert headers["Authorization"] == "Bearer valid-api-key-123"
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import json
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -110,9 +111,11 @@ class TestFirecrawlAuth:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("status_code", "response_text", "has_json_error", "expected_error_contains"),
|
("status_code", "response_text", "has_json_error", "expected_error_contains"),
|
||||||
[
|
[
|
||||||
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
|
(403, '{"error": "Forbidden"}', False, "Failed to authorize. Status code: 403. Error: Forbidden"),
|
||||||
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
|
# empty body falls back to generic message
|
||||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
(404, "", True, "Failed to authorize. Status code: 404. Error: Unknown error occurred"),
|
||||||
|
# non-JSON body is surfaced directly
|
||||||
|
(401, "Not JSON", True, "Failed to authorize. Status code: 401. Error: Not JSON"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||||
|
|
@ -124,12 +127,14 @@ class TestFirecrawlAuth:
|
||||||
mock_response.status_code = status_code
|
mock_response.status_code = status_code
|
||||||
mock_response.text = response_text
|
mock_response.text = response_text
|
||||||
if has_json_error:
|
if has_json_error:
|
||||||
mock_response.json.side_effect = Exception("Not JSON")
|
mock_response.json.side_effect = json.JSONDecodeError("Not JSON", "", 0)
|
||||||
|
else:
|
||||||
|
mock_response.json.return_value = {"error": "Forbidden"}
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
with pytest.raises(Exception) as exc_info:
|
with pytest.raises(Exception) as exc_info:
|
||||||
auth_instance.validate_credentials()
|
auth_instance.validate_credentials()
|
||||||
assert expected_error_contains in str(exc_info.value)
|
assert str(exc_info.value) == expected_error_contains
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("exception_type", "exception_message"),
|
("exception_type", "exception_message"),
|
||||||
|
|
@ -164,20 +169,21 @@ class TestFirecrawlAuth:
|
||||||
|
|
||||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||||
def test_should_use_custom_base_url_in_validation(self, mock_post):
|
def test_should_use_custom_base_url_in_validation(self, mock_post):
|
||||||
"""Test that custom base URL is used in validation"""
|
"""Test that custom base URL is used in validation and normalized"""
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_post.return_value = mock_response
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
credentials = {
|
for base in ("https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"):
|
||||||
"auth_type": "bearer",
|
credentials = {
|
||||||
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
|
"auth_type": "bearer",
|
||||||
}
|
"config": {"api_key": "test_api_key_123", "base_url": base},
|
||||||
auth = FirecrawlAuth(credentials)
|
}
|
||||||
result = auth.validate_credentials()
|
auth = FirecrawlAuth(credentials)
|
||||||
|
result = auth.validate_credentials()
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||||
|
|
||||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
|
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from models import Account
|
||||||
|
from services import app_dsl_service
|
||||||
|
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
|
||||||
|
|
||||||
|
|
||||||
|
def _build_response(url: str, status_code: int, content: bytes = b"") -> httpx.Response:
|
||||||
|
request = httpx.Request("GET", url)
|
||||||
|
return httpx.Response(status_code=status_code, request=request, content=content)
|
||||||
|
|
||||||
|
|
||||||
|
def _pending_yaml_content(version: str = "99.0.0") -> bytes:
|
||||||
|
return (f'version: "{version}"\nkind: app\napp:\n name: Loop Test\n mode: workflow\n').encode()
|
||||||
|
|
||||||
|
|
||||||
|
def _account_mock() -> MagicMock:
|
||||||
|
account = MagicMock(spec=Account)
|
||||||
|
account.current_tenant_id = "tenant-1"
|
||||||
|
return account
|
||||||
|
|
||||||
|
|
||||||
|
def test_import_app_yaml_url_user_attachments_keeps_original_url(monkeypatch):
|
||||||
|
yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml"
|
||||||
|
raw_url = "https://raw.githubusercontent.com/user-attachments/files/24290802/loop-test.yml"
|
||||||
|
yaml_bytes = _pending_yaml_content()
|
||||||
|
|
||||||
|
def fake_get(url: str, **kwargs):
|
||||||
|
if url == raw_url:
|
||||||
|
return _build_response(url, status_code=404)
|
||||||
|
assert url == yaml_url
|
||||||
|
return _build_response(url, status_code=200, content=yaml_bytes)
|
||||||
|
|
||||||
|
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
|
||||||
|
|
||||||
|
service = AppDslService(MagicMock())
|
||||||
|
result = service.import_app(
|
||||||
|
account=_account_mock(),
|
||||||
|
import_mode=ImportMode.YAML_URL,
|
||||||
|
yaml_url=yaml_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == ImportStatus.PENDING
|
||||||
|
assert result.imported_dsl_version == "99.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_import_app_yaml_url_github_blob_rewrites_to_raw(monkeypatch):
|
||||||
|
yaml_url = "https://github.com/acme/repo/blob/main/app.yml"
|
||||||
|
raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml"
|
||||||
|
yaml_bytes = _pending_yaml_content()
|
||||||
|
|
||||||
|
requested_urls: list[str] = []
|
||||||
|
|
||||||
|
def fake_get(url: str, **kwargs):
|
||||||
|
requested_urls.append(url)
|
||||||
|
assert url == raw_url
|
||||||
|
return _build_response(url, status_code=200, content=yaml_bytes)
|
||||||
|
|
||||||
|
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
|
||||||
|
|
||||||
|
service = AppDslService(MagicMock())
|
||||||
|
result = service.import_app(
|
||||||
|
account=_account_mock(),
|
||||||
|
import_mode=ImportMode.YAML_URL,
|
||||||
|
yaml_url=yaml_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.status == ImportStatus.PENDING
|
||||||
|
assert requested_urls == [raw_url]
|
||||||
|
|
@ -0,0 +1,520 @@
|
||||||
|
"""
|
||||||
|
Unit tests for document indexing sync task.
|
||||||
|
|
||||||
|
This module tests the document indexing sync task functionality including:
|
||||||
|
- Syncing Notion documents when updated
|
||||||
|
- Validating document and data source existence
|
||||||
|
- Credential validation and retrieval
|
||||||
|
- Cleaning old segments before re-indexing
|
||||||
|
- Error handling and edge cases
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||||
|
from models.dataset import Dataset, Document, DocumentSegment
|
||||||
|
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Fixtures
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tenant_id():
|
||||||
|
"""Generate a unique tenant ID for testing."""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dataset_id():
|
||||||
|
"""Generate a unique dataset ID for testing."""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def document_id():
|
||||||
|
"""Generate a unique document ID for testing."""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def notion_workspace_id():
|
||||||
|
"""Generate a Notion workspace ID for testing."""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def notion_page_id():
|
||||||
|
"""Generate a Notion page ID for testing."""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def credential_id():
|
||||||
|
"""Generate a credential ID for testing."""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dataset(dataset_id, tenant_id):
|
||||||
|
"""Create a mock Dataset object."""
|
||||||
|
dataset = Mock(spec=Dataset)
|
||||||
|
dataset.id = dataset_id
|
||||||
|
dataset.tenant_id = tenant_id
|
||||||
|
dataset.indexing_technique = "high_quality"
|
||||||
|
dataset.embedding_model_provider = "openai"
|
||||||
|
dataset.embedding_model = "text-embedding-ada-002"
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id):
|
||||||
|
"""Create a mock Document object with Notion data source."""
|
||||||
|
doc = Mock(spec=Document)
|
||||||
|
doc.id = document_id
|
||||||
|
doc.dataset_id = dataset_id
|
||||||
|
doc.tenant_id = tenant_id
|
||||||
|
doc.data_source_type = "notion_import"
|
||||||
|
doc.indexing_status = "completed"
|
||||||
|
doc.error = None
|
||||||
|
doc.stopped_at = None
|
||||||
|
doc.processing_started_at = None
|
||||||
|
doc.doc_form = "text_model"
|
||||||
|
doc.data_source_info_dict = {
|
||||||
|
"notion_workspace_id": notion_workspace_id,
|
||||||
|
"notion_page_id": notion_page_id,
|
||||||
|
"type": "page",
|
||||||
|
"last_edited_time": "2024-01-01T00:00:00Z",
|
||||||
|
"credential_id": credential_id,
|
||||||
|
}
|
||||||
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_document_segments(document_id):
|
||||||
|
"""Create mock DocumentSegment objects."""
|
||||||
|
segments = []
|
||||||
|
for i in range(3):
|
||||||
|
segment = Mock(spec=DocumentSegment)
|
||||||
|
segment.id = str(uuid.uuid4())
|
||||||
|
segment.document_id = document_id
|
||||||
|
segment.index_node_id = f"node-{document_id}-{i}"
|
||||||
|
segments.append(segment)
|
||||||
|
return segments
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_db_session():
|
||||||
|
"""Mock database session."""
|
||||||
|
with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
|
||||||
|
mock_query = MagicMock()
|
||||||
|
mock_session.query.return_value = mock_query
|
||||||
|
mock_query.where.return_value = mock_query
|
||||||
|
mock_session.scalars.return_value = MagicMock()
|
||||||
|
yield mock_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_datasource_provider_service():
|
||||||
|
"""Mock DatasourceProviderService."""
|
||||||
|
with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class:
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
|
||||||
|
mock_service_class.return_value = mock_service
|
||||||
|
yield mock_service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_notion_extractor():
|
||||||
|
"""Mock NotionExtractor."""
|
||||||
|
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
|
||||||
|
mock_extractor = MagicMock()
|
||||||
|
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Updated time
|
||||||
|
mock_extractor_class.return_value = mock_extractor
|
||||||
|
yield mock_extractor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_index_processor_factory():
|
||||||
|
"""Mock IndexProcessorFactory."""
|
||||||
|
with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory:
|
||||||
|
mock_processor = MagicMock()
|
||||||
|
mock_processor.clean = Mock()
|
||||||
|
mock_factory.return_value.init_index_processor.return_value = mock_processor
|
||||||
|
yield mock_factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_indexing_runner():
|
||||||
|
"""Mock IndexingRunner."""
|
||||||
|
with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class:
|
||||||
|
mock_runner = MagicMock(spec=IndexingRunner)
|
||||||
|
mock_runner.run = Mock()
|
||||||
|
mock_runner_class.return_value = mock_runner
|
||||||
|
yield mock_runner
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for document_indexing_sync_task
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentIndexingSyncTask:
|
||||||
|
"""Tests for the document_indexing_sync_task function."""
|
||||||
|
|
||||||
|
def test_document_not_found(self, mock_db_session, dataset_id, document_id):
|
||||||
|
"""Test that task handles document not found gracefully."""
|
||||||
|
# Arrange
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.return_value = None
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_db_session.close.assert_called_once()
|
||||||
|
|
||||||
|
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||||
|
"""Test that task raises error when notion_workspace_id is missing."""
|
||||||
|
# Arrange
|
||||||
|
mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"}
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValueError, match="no notion page found"):
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||||
|
"""Test that task raises error when notion_page_id is missing."""
|
||||||
|
# Arrange
|
||||||
|
mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"}
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValueError, match="no notion page found"):
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||||
|
"""Test that task raises error when data_source_info is empty."""
|
||||||
|
# Arrange
|
||||||
|
mock_document.data_source_info_dict = None
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||||
|
|
||||||
|
# Act & Assert
|
||||||
|
with pytest.raises(ValueError, match="no notion page found"):
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
def test_credential_not_found(
|
||||||
|
self,
|
||||||
|
mock_db_session,
|
||||||
|
mock_datasource_provider_service,
|
||||||
|
mock_document,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
):
|
||||||
|
"""Test that task handles missing credentials by updating document status."""
|
||||||
|
# Arrange
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||||
|
mock_datasource_provider_service.get_datasource_credentials.return_value = None
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert mock_document.indexing_status == "error"
|
||||||
|
assert "Datasource credential not found" in mock_document.error
|
||||||
|
assert mock_document.stopped_at is not None
|
||||||
|
mock_db_session.commit.assert_called()
|
||||||
|
mock_db_session.close.assert_called()
|
||||||
|
|
||||||
|
def test_page_not_updated(
|
||||||
|
self,
|
||||||
|
mock_db_session,
|
||||||
|
mock_datasource_provider_service,
|
||||||
|
mock_notion_extractor,
|
||||||
|
mock_document,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
):
|
||||||
|
"""Test that task does nothing when page has not been updated."""
|
||||||
|
# Arrange
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||||
|
# Return same time as stored in document
|
||||||
|
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# Document status should remain unchanged
|
||||||
|
assert mock_document.indexing_status == "completed"
|
||||||
|
# No session operations should be performed beyond the initial query
|
||||||
|
mock_db_session.close.assert_not_called()
|
||||||
|
|
||||||
|
def test_successful_sync_when_page_updated(
|
||||||
|
self,
|
||||||
|
mock_db_session,
|
||||||
|
mock_datasource_provider_service,
|
||||||
|
mock_notion_extractor,
|
||||||
|
mock_index_processor_factory,
|
||||||
|
mock_indexing_runner,
|
||||||
|
mock_dataset,
|
||||||
|
mock_document,
|
||||||
|
mock_document_segments,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
):
|
||||||
|
"""Test successful sync flow when Notion page has been updated."""
|
||||||
|
# Arrange
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||||
|
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||||
|
# NotionExtractor returns updated time
|
||||||
|
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# Verify document status was updated to parsing
|
||||||
|
assert mock_document.indexing_status == "parsing"
|
||||||
|
assert mock_document.processing_started_at is not None
|
||||||
|
|
||||||
|
# Verify segments were cleaned
|
||||||
|
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||||
|
mock_processor.clean.assert_called_once()
|
||||||
|
|
||||||
|
# Verify segments were deleted from database
|
||||||
|
for segment in mock_document_segments:
|
||||||
|
mock_db_session.delete.assert_any_call(segment)
|
||||||
|
|
||||||
|
# Verify indexing runner was called
|
||||||
|
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||||
|
|
||||||
|
# Verify session operations
|
||||||
|
assert mock_db_session.commit.called
|
||||||
|
mock_db_session.close.assert_called_once()
|
||||||
|
|
||||||
|
def test_dataset_not_found_during_cleaning(
|
||||||
|
self,
|
||||||
|
mock_db_session,
|
||||||
|
mock_datasource_provider_service,
|
||||||
|
mock_notion_extractor,
|
||||||
|
mock_document,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
):
|
||||||
|
"""Test that task handles dataset not found during cleaning phase."""
|
||||||
|
# Arrange
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
|
||||||
|
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# Document should still be set to parsing
|
||||||
|
assert mock_document.indexing_status == "parsing"
|
||||||
|
# Session should be closed after error
|
||||||
|
mock_db_session.close.assert_called_once()
|
||||||
|
|
||||||
|
def test_cleaning_error_continues_to_indexing(
|
||||||
|
self,
|
||||||
|
mock_db_session,
|
||||||
|
mock_datasource_provider_service,
|
||||||
|
mock_notion_extractor,
|
||||||
|
mock_index_processor_factory,
|
||||||
|
mock_indexing_runner,
|
||||||
|
mock_dataset,
|
||||||
|
mock_document,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
):
|
||||||
|
"""Test that indexing continues even if cleaning fails."""
|
||||||
|
# Arrange
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||||
|
mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
|
||||||
|
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# Indexing should still be attempted despite cleaning error
|
||||||
|
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||||
|
mock_db_session.close.assert_called_once()
|
||||||
|
|
||||||
|
def test_indexing_runner_document_paused_error(
|
||||||
|
self,
|
||||||
|
mock_db_session,
|
||||||
|
mock_datasource_provider_service,
|
||||||
|
mock_notion_extractor,
|
||||||
|
mock_index_processor_factory,
|
||||||
|
mock_indexing_runner,
|
||||||
|
mock_dataset,
|
||||||
|
mock_document,
|
||||||
|
mock_document_segments,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
):
|
||||||
|
"""Test that DocumentIsPausedError is handled gracefully."""
|
||||||
|
# Arrange
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||||
|
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||||
|
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||||
|
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# Session should be closed after handling error
|
||||||
|
mock_db_session.close.assert_called_once()
|
||||||
|
|
||||||
|
def test_indexing_runner_general_error(
|
||||||
|
self,
|
||||||
|
mock_db_session,
|
||||||
|
mock_datasource_provider_service,
|
||||||
|
mock_notion_extractor,
|
||||||
|
mock_index_processor_factory,
|
||||||
|
mock_indexing_runner,
|
||||||
|
mock_dataset,
|
||||||
|
mock_document,
|
||||||
|
mock_document_segments,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
):
|
||||||
|
"""Test that general exceptions during indexing are handled."""
|
||||||
|
# Arrange
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||||
|
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||||
|
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||||
|
mock_indexing_runner.run.side_effect = Exception("Indexing error")
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# Session should be closed after error
|
||||||
|
mock_db_session.close.assert_called_once()
|
||||||
|
|
||||||
|
def test_notion_extractor_initialized_with_correct_params(
|
||||||
|
self,
|
||||||
|
mock_db_session,
|
||||||
|
mock_datasource_provider_service,
|
||||||
|
mock_notion_extractor,
|
||||||
|
mock_document,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
notion_workspace_id,
|
||||||
|
notion_page_id,
|
||||||
|
):
|
||||||
|
"""Test that NotionExtractor is initialized with correct parameters."""
|
||||||
|
# Arrange
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||||
|
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # No update
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
|
||||||
|
mock_extractor = MagicMock()
|
||||||
|
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||||
|
mock_extractor_class.return_value = mock_extractor
|
||||||
|
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_extractor_class.assert_called_once_with(
|
||||||
|
notion_workspace_id=notion_workspace_id,
|
||||||
|
notion_obj_id=notion_page_id,
|
||||||
|
notion_page_type="page",
|
||||||
|
notion_access_token="test_token",
|
||||||
|
tenant_id=mock_document.tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_datasource_credentials_requested_correctly(
|
||||||
|
self,
|
||||||
|
mock_db_session,
|
||||||
|
mock_datasource_provider_service,
|
||||||
|
mock_notion_extractor,
|
||||||
|
mock_document,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
credential_id,
|
||||||
|
):
|
||||||
|
"""Test that datasource credentials are requested with correct parameters."""
|
||||||
|
# Arrange
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||||
|
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
|
||||||
|
tenant_id=mock_document.tenant_id,
|
||||||
|
credential_id=credential_id,
|
||||||
|
provider="notion_datasource",
|
||||||
|
plugin_id="langgenius/notion_datasource",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_credential_id_missing_uses_none(
|
||||||
|
self,
|
||||||
|
mock_db_session,
|
||||||
|
mock_datasource_provider_service,
|
||||||
|
mock_notion_extractor,
|
||||||
|
mock_document,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
):
|
||||||
|
"""Test that task handles missing credential_id by passing None."""
|
||||||
|
# Arrange
|
||||||
|
mock_document.data_source_info_dict = {
|
||||||
|
"notion_workspace_id": "ws123",
|
||||||
|
"notion_page_id": "page123",
|
||||||
|
"type": "page",
|
||||||
|
"last_edited_time": "2024-01-01T00:00:00Z",
|
||||||
|
}
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||||
|
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
|
||||||
|
tenant_id=mock_document.tenant_id,
|
||||||
|
credential_id=None,
|
||||||
|
provider="notion_datasource",
|
||||||
|
plugin_id="langgenius/notion_datasource",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_index_processor_clean_called_with_correct_params(
|
||||||
|
self,
|
||||||
|
mock_db_session,
|
||||||
|
mock_datasource_provider_service,
|
||||||
|
mock_notion_extractor,
|
||||||
|
mock_index_processor_factory,
|
||||||
|
mock_indexing_runner,
|
||||||
|
mock_dataset,
|
||||||
|
mock_document,
|
||||||
|
mock_document_segments,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
):
|
||||||
|
"""Test that index processor clean is called with correct parameters."""
|
||||||
|
# Arrange
|
||||||
|
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||||
|
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||||
|
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||||
|
|
||||||
|
# Act
|
||||||
|
document_indexing_sync_task(dataset_id, document_id)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||||
|
expected_node_ids = [seg.index_node_id for seg in mock_document_segments]
|
||||||
|
mock_processor.clean.assert_called_once_with(
|
||||||
|
mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True
|
||||||
|
)
|
||||||
39
api/uv.lock
39
api/uv.lock
|
|
@ -1636,7 +1636,7 @@ requires-dist = [
|
||||||
{ name = "pydantic-extra-types", specifier = "~=2.10.3" },
|
{ name = "pydantic-extra-types", specifier = "~=2.10.3" },
|
||||||
{ name = "pydantic-settings", specifier = "~=2.11.0" },
|
{ name = "pydantic-settings", specifier = "~=2.11.0" },
|
||||||
{ name = "pyjwt", specifier = "~=2.10.1" },
|
{ name = "pyjwt", specifier = "~=2.10.1" },
|
||||||
{ name = "pypdfium2", specifier = "==4.30.0" },
|
{ name = "pypdfium2", specifier = "==5.2.0" },
|
||||||
{ name = "python-docx", specifier = "~=1.1.0" },
|
{ name = "python-docx", specifier = "~=1.1.0" },
|
||||||
{ name = "python-dotenv", specifier = "==1.0.1" },
|
{ name = "python-dotenv", specifier = "==1.0.1" },
|
||||||
{ name = "pyyaml", specifier = "~=6.0.1" },
|
{ name = "pyyaml", specifier = "~=6.0.1" },
|
||||||
|
|
@ -4993,22 +4993,31 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pypdfium2"
|
name = "pypdfium2"
|
||||||
version = "4.30.0"
|
version = "5.2.0"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/a1/14/838b3ba247a0ba92e4df5d23f2bea9478edcfd72b78a39d6ca36ccd84ad2/pypdfium2-4.30.0.tar.gz", hash = "sha256:48b5b7e5566665bc1015b9d69c1ebabe21f6aee468b509531c3c8318eeee2e16", size = 140239, upload-time = "2024-05-09T18:33:17.552Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/f6/ab/73c7d24e4eac9ba952569403b32b7cca9412fc5b9bef54fdbd669551389f/pypdfium2-5.2.0.tar.gz", hash = "sha256:43863625231ce999c1ebbed6721a88de818b2ab4d909c1de558d413b9a400256", size = 269999, upload-time = "2025-12-12T13:20:15.353Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/c7/9a/c8ff5cc352c1b60b0b97642ae734f51edbab6e28b45b4fcdfe5306ee3c83/pypdfium2-4.30.0-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:b33ceded0b6ff5b2b93bc1fe0ad4b71aa6b7e7bd5875f1ca0cdfb6ba6ac01aab", size = 2837254, upload-time = "2024-05-09T18:32:48.653Z" },
|
{ url = "https://files.pythonhosted.org/packages/fb/0c/9108ae5266ee4cdf495f99205c44d4b5c83b4eb227c2b610d35c9e9fe961/pypdfium2-5.2.0-py3-none-android_23_arm64_v8a.whl", hash = "sha256:1ba4187a45ce4cf08f2a8c7e0f8970c36b9aa1770c8a3412a70781c1d80fb145", size = 2763268, upload-time = "2025-12-12T13:19:37.354Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/21/8b/27d4d5409f3c76b985f4ee4afe147b606594411e15ac4dc1c3363c9a9810/pypdfium2-4.30.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:4e55689f4b06e2d2406203e771f78789bd4f190731b5d57383d05cf611d829de", size = 2707624, upload-time = "2024-05-09T18:32:51.458Z" },
|
{ url = "https://files.pythonhosted.org/packages/35/8c/55f5c8a2c6b293f5c020be4aa123eaa891e797c514e5eccd8cb042740d37/pypdfium2-5.2.0-py3-none-android_23_armeabi_v7a.whl", hash = "sha256:80c55e10a8c9242f0901d35a9a306dd09accce8e497507bb23fcec017d45fe2e", size = 2301821, upload-time = "2025-12-12T13:19:39.484Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/11/63/28a73ca17c24b41a205d658e177d68e198d7dde65a8c99c821d231b6ee3d/pypdfium2-4.30.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e6e50f5ce7f65a40a33d7c9edc39f23140c57e37144c2d6d9e9262a2a854854", size = 2793126, upload-time = "2024-05-09T18:32:53.581Z" },
|
{ url = "https://files.pythonhosted.org/packages/5e/7d/efa013e3795b41c59dd1e472f7201c241232c3a6553be4917e3a26b9f225/pypdfium2-5.2.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:73523ae69cd95c084c1342096893b2143ea73c36fdde35494780ba431e6a7d6e", size = 2816428, upload-time = "2025-12-12T13:19:41.735Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/d1/96/53b3ebf0955edbd02ac6da16a818ecc65c939e98fdeb4e0958362bd385c8/pypdfium2-4.30.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3d0dd3ecaffd0b6dbda3da663220e705cb563918249bda26058c6036752ba3a2", size = 2591077, upload-time = "2024-05-09T18:32:55.99Z" },
|
{ url = "https://files.pythonhosted.org/packages/ec/ae/8c30af6ff2ab41a7cb84753ee79dd1e0a8932c9bda9fe19759d69cbbf115/pypdfium2-5.2.0-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:19c501d22ef5eb98e42416d22cc3ac66d4808b436e3d06686392f24d8d9f708d", size = 2939486, upload-time = "2025-12-12T13:19:43.176Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/ec/ee/0394e56e7cab8b5b21f744d988400948ef71a9a892cbeb0b200d324ab2c7/pypdfium2-4.30.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cc3bf29b0db8c76cdfaac1ec1cde8edf211a7de7390fbf8934ad2aa9b4d6dfad", size = 2864431, upload-time = "2024-05-09T18:32:57.911Z" },
|
{ url = "https://files.pythonhosted.org/packages/64/64/454a73c49a04c2c290917ad86184e4da959e9e5aba94b3b046328c89be93/pypdfium2-5.2.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ed15a3f58d6ee4905f0d0a731e30b381b457c30689512589c7f57950b0cdcec", size = 2979235, upload-time = "2025-12-12T13:19:44.635Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/65/cd/3f1edf20a0ef4a212a5e20a5900e64942c5a374473671ac0780eaa08ea80/pypdfium2-4.30.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1f78d2189e0ddf9ac2b7a9b9bd4f0c66f54d1389ff6c17e9fd9dc034d06eb3f", size = 2812008, upload-time = "2024-05-09T18:32:59.886Z" },
|
{ url = "https://files.pythonhosted.org/packages/4e/29/f1cab8e31192dd367dc7b1afa71f45cfcb8ff0b176f1d2a0f528faf04052/pypdfium2-5.2.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:329cd1e9f068e8729e0d0b79a070d6126f52bc48ff1e40505cb207a5e20ce0ba", size = 2763001, upload-time = "2025-12-12T13:19:47.598Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c8/91/2d517db61845698f41a2a974de90762e50faeb529201c6b3574935969045/pypdfium2-4.30.0-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:5eda3641a2da7a7a0b2f4dbd71d706401a656fea521b6b6faa0675b15d31a163", size = 6181543, upload-time = "2024-05-09T18:33:02.597Z" },
|
{ url = "https://files.pythonhosted.org/packages/bc/5d/e95fad8fdac960854173469c4b6931d5de5e09d05e6ee7d9756f8b95eef0/pypdfium2-5.2.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:325259759886e66619504df4721fef3b8deabf8a233e4f4a66e0c32ebae60c2f", size = 3057024, upload-time = "2025-12-12T13:19:49.179Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/ba/c4/ed1315143a7a84b2c7616569dfb472473968d628f17c231c39e29ae9d780/pypdfium2-4.30.0-py3-none-musllinux_1_1_i686.whl", hash = "sha256:0dfa61421b5eb68e1188b0b2231e7ba35735aef2d867d86e48ee6cab6975195e", size = 6175911, upload-time = "2024-05-09T18:33:05.376Z" },
|
{ url = "https://files.pythonhosted.org/packages/f4/32/468591d017ab67f8142d40f4db8163b6d8bb404fe0d22da75a5c661dc144/pypdfium2-5.2.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5683e8f08ab38ed05e0e59e611451ec74332803d4e78f8c45658ea1d372a17af", size = 3448598, upload-time = "2025-12-12T13:19:50.979Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/7a/c4/9e62d03f414e0e3051c56d5943c3bf42aa9608ede4e19dc96438364e9e03/pypdfium2-4.30.0-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:f33bd79e7a09d5f7acca3b0b69ff6c8a488869a7fab48fdf400fec6e20b9c8be", size = 6267430, upload-time = "2024-05-09T18:33:08.067Z" },
|
{ url = "https://files.pythonhosted.org/packages/f9/a5/57b4e389b77ab5f7e9361dc7fc03b5378e678ba81b21e791e85350fbb235/pypdfium2-5.2.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da4815426a5adcf03bf4d2c5f26c0ff8109dbfaf2c3415984689931bc6006ef9", size = 2993946, upload-time = "2025-12-12T13:19:53.154Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/90/47/eda4904f715fb98561e34012826e883816945934a851745570521ec89520/pypdfium2-4.30.0-py3-none-win32.whl", hash = "sha256:ee2410f15d576d976c2ab2558c93d392a25fb9f6635e8dd0a8a3a5241b275e0e", size = 2775951, upload-time = "2024-05-09T18:33:10.567Z" },
|
{ url = "https://files.pythonhosted.org/packages/84/3a/e03e9978f817632aa56183bb7a4989284086fdd45de3245ead35f147179b/pypdfium2-5.2.0-py3-none-manylinux_2_27_s390x.manylinux_2_28_s390x.whl", hash = "sha256:64bf5c039b2c314dab1fd158bfff99db96299a5b5c6d96fc056071166056f1de", size = 3673148, upload-time = "2025-12-12T13:19:54.528Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/25/bd/56d9ec6b9f0fc4e0d95288759f3179f0fcd34b1a1526b75673d2f6d5196f/pypdfium2-4.30.0-py3-none-win_amd64.whl", hash = "sha256:90dbb2ac07be53219f56be09961eb95cf2473f834d01a42d901d13ccfad64b4c", size = 2892098, upload-time = "2024-05-09T18:33:13.107Z" },
|
{ url = "https://files.pythonhosted.org/packages/13/ee/e581506806553afa4b7939d47bf50dca35c1151b8cc960f4542a6eb135ce/pypdfium2-5.2.0-py3-none-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:76b42a17748ac7dc04d5ef04d0561c6a0a4b546d113ec1d101d59650c6a340f7", size = 2964757, upload-time = "2025-12-12T13:19:56.406Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/be/7a/097801205b991bc3115e8af1edb850d30aeaf0118520b016354cf5ccd3f6/pypdfium2-4.30.0-py3-none-win_arm64.whl", hash = "sha256:119b2969a6d6b1e8d55e99caaf05290294f2d0fe49c12a3f17102d01c441bd29", size = 2752118, upload-time = "2024-05-09T18:33:15.489Z" },
|
{ url = "https://files.pythonhosted.org/packages/00/be/3715c652aff30f12284523dd337843d0efe3e721020f0ec303a99ffffd8d/pypdfium2-5.2.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:9d4367d471439fae846f0aba91ff9e8d66e524edcf3c8d6e02fe96fa306e13b9", size = 4130319, upload-time = "2025-12-12T13:19:57.889Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b0/0b/28aa2ede9004dd4192266bbad394df0896787f7c7bcfa4d1a6e091ad9a2c/pypdfium2-5.2.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:613f6bb2b47d76b66c0bf2ca581c7c33e3dd9dcb29d65d8c34fef4135f933149", size = 3746488, upload-time = "2025-12-12T13:19:59.469Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bc/04/1b791e1219652bbfc51df6498267d8dcec73ad508b99388b2890902ccd9d/pypdfium2-5.2.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c03fad3f2fa68d358f5dd4deb07e438482fa26fae439c49d127576d969769ca1", size = 4336534, upload-time = "2025-12-12T13:20:01.28Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4f/e3/6f00f963bb702ffd2e3e2d9c7286bc3bb0bebcdfa96ca897d466f66976c6/pypdfium2-5.2.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:f10be1900ae21879d02d9f4d58c2d2db3a2e6da611736a8e9decc22d1fb02909", size = 4375079, upload-time = "2025-12-12T13:20:03.117Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3a/2a/7ec2b191b5e1b7716a0dfc14e6860e89bb355fb3b94ed0c1d46db526858c/pypdfium2-5.2.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:97c1a126d30378726872f94866e38c055740cae80313638dafd1cd448d05e7c0", size = 3928648, upload-time = "2025-12-12T13:20:05.041Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bf/c3/c6d972fa095ff3ace76f9d3a91ceaf8a9dbbe0d9a5a84ac1d6178a46630e/pypdfium2-5.2.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:c369f183a90781b788af9a357a877bc8caddc24801e8346d0bf23f3295f89f3a", size = 4997772, upload-time = "2025-12-12T13:20:06.453Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/22/45/2c64584b7a3ca5c4652280a884f4b85b8ed24e27662adeebdc06d991c917/pypdfium2-5.2.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b391f1cceb454934b612a05b54e90f98aafeffe5e73830d71700b17f0812226b", size = 4180046, upload-time = "2025-12-12T13:20:08.715Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d6/99/8d1ff87b626649400e62a2840e6e10fe258443ba518798e071fee4cd86f9/pypdfium2-5.2.0-py3-none-win32.whl", hash = "sha256:c68067938f617c37e4d17b18de7cac231fc7ce0eb7b6653b7283ebe8764d4999", size = 2990175, upload-time = "2025-12-12T13:20:10.241Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/93/fc/114fff8895b620aac4984808e93d01b6d7b93e342a1635fcfe2a5f39cf39/pypdfium2-5.2.0-py3-none-win_amd64.whl", hash = "sha256:eb0591b720e8aaeab9475c66d653655ec1be0464b946f3f48a53922e843f0f3b", size = 3098615, upload-time = "2025-12-12T13:20:11.795Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/08/97/eb738bff5998760d6e0cbcb7dd04cbf1a95a97b997fac6d4e57562a58992/pypdfium2-5.2.0-py3-none-win_arm64.whl", hash = "sha256:5dd1ef579f19fa3719aee4959b28bda44b1072405756708b5e83df8806a19521", size = 2939479, upload-time = "2025-12-12T13:20:13.815Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
||||||
|
|
@ -468,6 +468,7 @@ ALIYUN_OSS_REGION=ap-southeast-1
|
||||||
ALIYUN_OSS_AUTH_VERSION=v4
|
ALIYUN_OSS_AUTH_VERSION=v4
|
||||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||||
ALIYUN_OSS_PATH=your-path
|
ALIYUN_OSS_PATH=your-path
|
||||||
|
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||||
|
|
||||||
# Tencent COS Configuration
|
# Tencent COS Configuration
|
||||||
#
|
#
|
||||||
|
|
@ -491,6 +492,7 @@ HUAWEI_OBS_BUCKET_NAME=your-bucket-name
|
||||||
HUAWEI_OBS_SECRET_KEY=your-secret-key
|
HUAWEI_OBS_SECRET_KEY=your-secret-key
|
||||||
HUAWEI_OBS_ACCESS_KEY=your-access-key
|
HUAWEI_OBS_ACCESS_KEY=your-access-key
|
||||||
HUAWEI_OBS_SERVER=your-server-url
|
HUAWEI_OBS_SERVER=your-server-url
|
||||||
|
HUAWEI_OBS_PATH_STYLE=false
|
||||||
|
|
||||||
# Volcengine TOS Configuration
|
# Volcengine TOS Configuration
|
||||||
#
|
#
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,10 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
|
||||||
- Navigate to the `docker` directory.
|
- Navigate to the `docker` directory.
|
||||||
- Copy the `.env.example` file to a new file named `.env` by running `cp .env.example .env`.
|
- Copy the `.env.example` file to a new file named `.env` by running `cp .env.example .env`.
|
||||||
- Customize the `.env` file as needed. Refer to the `.env.example` file for detailed configuration options.
|
- Customize the `.env` file as needed. Refer to the `.env.example` file for detailed configuration options.
|
||||||
|
- **Optional (Recommended for upgrades)**:
|
||||||
|
You may use the environment synchronization tool to help keep your `.env` file aligned with the latest `.env.example` updates, while preserving your custom settings.
|
||||||
|
This is especially useful when upgrading Dify or managing a large, customized `.env` file.
|
||||||
|
See the [Environment Variables Synchronization](#environment-variables-synchronization) section below.
|
||||||
1. **Running the Services**:
|
1. **Running the Services**:
|
||||||
- Execute `docker compose up` from the `docker` directory to start the services.
|
- Execute `docker compose up` from the `docker` directory to start the services.
|
||||||
- To specify a vector database, set the `VECTOR_STORE` variable in your `.env` file to your desired vector database service, such as `milvus`, `weaviate`, or `opensearch`.
|
- To specify a vector database, set the `VECTOR_STORE` variable in your `.env` file to your desired vector database service, such as `milvus`, `weaviate`, or `opensearch`.
|
||||||
|
|
@ -111,6 +115,47 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w
|
||||||
|
|
||||||
- Each service like `nginx`, `redis`, `db`, and vector databases have specific environment variables that are directly referenced in the `docker-compose.yaml`.
|
- Each service like `nginx`, `redis`, `db`, and vector databases have specific environment variables that are directly referenced in the `docker-compose.yaml`.
|
||||||
|
|
||||||
|
### Environment Variables Synchronization
|
||||||
|
|
||||||
|
When upgrading Dify or pulling the latest changes, new environment variables may be introduced in `.env.example`.
|
||||||
|
|
||||||
|
To help keep your existing `.env` file up to date **without losing your custom values**, an optional environment variables synchronization tool is provided.
|
||||||
|
|
||||||
|
> This tool performs a **one-way synchronization** from `.env.example` to `.env`.
|
||||||
|
> Existing values in `.env` are never overwritten automatically.
|
||||||
|
|
||||||
|
#### `dify-env-sync.sh` (Optional)
|
||||||
|
|
||||||
|
This script compares your current `.env` file with the latest `.env.example` template and helps safely apply new or updated environment variables.
|
||||||
|
|
||||||
|
**What it does**
|
||||||
|
|
||||||
|
- Creates a backup of the current `.env` file before making any changes
|
||||||
|
- Synchronizes newly added environment variables from `.env.example`
|
||||||
|
- Preserves all existing custom values in `.env`
|
||||||
|
- Displays differences and variables removed from `.env.example` for review
|
||||||
|
|
||||||
|
**Backup behavior**
|
||||||
|
|
||||||
|
Before synchronization, the current `.env` file is saved to the `env-backup/` directory with a timestamped filename
|
||||||
|
(e.g. `env-backup/.env.backup_20231218_143022`).
|
||||||
|
|
||||||
|
**When to use**
|
||||||
|
|
||||||
|
- After upgrading Dify to a newer version
|
||||||
|
- When `.env.example` has been updated with new environment variables
|
||||||
|
- When managing a large or heavily customized `.env` file
|
||||||
|
|
||||||
|
**Usage**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Grant execution permission (first time only)
|
||||||
|
chmod +x dify-env-sync.sh
|
||||||
|
|
||||||
|
# Run the synchronization
|
||||||
|
./dify-env-sync.sh
|
||||||
|
```
|
||||||
|
|
||||||
### Additional Information
|
### Additional Information
|
||||||
|
|
||||||
- **Continuous Improvement Phase**: We are actively seeking feedback from the community to refine and enhance the deployment process. As more users adopt this new method, we will continue to make improvements based on your experiences and suggestions.
|
- **Continuous Improvement Phase**: We are actively seeking feedback from the community to refine and enhance the deployment process. As more users adopt this new method, we will continue to make improvements based on your experiences and suggestions.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,465 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# ================================================================
|
||||||
|
# Dify Environment Variables Synchronization Script
|
||||||
|
#
|
||||||
|
# Features:
|
||||||
|
# - Synchronize latest settings from .env.example to .env
|
||||||
|
# - Preserve custom settings in existing .env
|
||||||
|
# - Add new environment variables
|
||||||
|
# - Detect removed environment variables
|
||||||
|
# - Create backup files
|
||||||
|
# ================================================================
|
||||||
|
|
||||||
|
set -eo pipefail # Exit on error and pipe failures (safer for complex variable handling)
|
||||||
|
|
||||||
|
# Error handling function
|
||||||
|
# Arguments:
|
||||||
|
# $1 - Line number where error occurred
|
||||||
|
# $2 - Error code
|
||||||
|
handle_error() {
|
||||||
|
local line_no=$1
|
||||||
|
local error_code=$2
|
||||||
|
echo -e "\033[0;31m[ERROR]\033[0m Script error: line $line_no with error code $error_code" >&2
|
||||||
|
echo -e "\033[0;31m[ERROR]\033[0m Debug info: current working directory $(pwd)" >&2
|
||||||
|
exit $error_code
|
||||||
|
}
|
||||||
|
|
||||||
|
# Set error trap
|
||||||
|
trap 'handle_error ${LINENO} $?' ERR
|
||||||
|
|
||||||
|
# Color settings for output
|
||||||
|
readonly RED='\033[0;31m'
|
||||||
|
readonly GREEN='\033[0;32m'
|
||||||
|
readonly YELLOW='\033[1;33m'
|
||||||
|
readonly BLUE='\033[0;34m'
|
||||||
|
readonly NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
# Logging functions
|
||||||
|
# Print informational message in blue
|
||||||
|
# Arguments: $1 - Message to print
|
||||||
|
log_info() {
|
||||||
|
echo -e "${BLUE}[INFO]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Print success message in green
|
||||||
|
# Arguments: $1 - Message to print
|
||||||
|
log_success() {
|
||||||
|
echo -e "${GREEN}[SUCCESS]${NC} $1"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Print warning message in yellow
|
||||||
|
# Arguments: $1 - Message to print
|
||||||
|
log_warning() {
|
||||||
|
echo -e "${YELLOW}[WARNING]${NC} $1" >&2
|
||||||
|
}
|
||||||
|
|
||||||
|
# Print error message in red to stderr
|
||||||
|
# Arguments: $1 - Message to print
|
||||||
|
log_error() {
|
||||||
|
echo -e "${RED}[ERROR]${NC} $1" >&2
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check for required files and create .env if missing
|
||||||
|
# Verifies that .env.example exists and creates .env from template if needed
|
||||||
|
check_files() {
|
||||||
|
log_info "Checking required files..."
|
||||||
|
|
||||||
|
if [[ ! -f ".env.example" ]]; then
|
||||||
|
log_error ".env.example file not found"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ ! -f ".env" ]]; then
|
||||||
|
log_warning ".env file does not exist. Creating from .env.example."
|
||||||
|
cp ".env.example" ".env"
|
||||||
|
log_success ".env file created"
|
||||||
|
fi
|
||||||
|
|
||||||
|
log_success "Required files verified"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create timestamped backup of .env file
|
||||||
|
# Creates env-backup directory if needed and backs up current .env file
|
||||||
|
create_backup() {
|
||||||
|
local timestamp=$(date +"%Y%m%d_%H%M%S")
|
||||||
|
local backup_dir="env-backup"
|
||||||
|
|
||||||
|
# Create backup directory if it doesn't exist
|
||||||
|
if [[ ! -d "$backup_dir" ]]; then
|
||||||
|
mkdir -p "$backup_dir"
|
||||||
|
log_info "Created backup directory: $backup_dir"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -f ".env" ]]; then
|
||||||
|
local backup_file="${backup_dir}/.env.backup_${timestamp}"
|
||||||
|
cp ".env" "$backup_file"
|
||||||
|
log_success "Backed up existing .env to $backup_file"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Detect differences between .env and .env.example (optimized for large files)
|
||||||
|
detect_differences() {
|
||||||
|
log_info "Detecting differences between .env and .env.example..."
|
||||||
|
|
||||||
|
# Create secure temporary directory
|
||||||
|
local temp_dir=$(mktemp -d)
|
||||||
|
local temp_diff="$temp_dir/env_diff"
|
||||||
|
|
||||||
|
# Store diff file path as global variable
|
||||||
|
declare -g DIFF_FILE="$temp_diff"
|
||||||
|
declare -g TEMP_DIR="$temp_dir"
|
||||||
|
|
||||||
|
# Initialize difference file
|
||||||
|
> "$temp_diff"
|
||||||
|
|
||||||
|
# Use awk for efficient comparison (much faster for large files)
|
||||||
|
local diff_count=$(awk -F= '
|
||||||
|
BEGIN { OFS="\x01" }
|
||||||
|
FNR==NR {
|
||||||
|
if (!/^[[:space:]]*#/ && !/^[[:space:]]*$/ && /=/) {
|
||||||
|
gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1)
|
||||||
|
key = $1
|
||||||
|
value = substr($0, index($0,"=")+1)
|
||||||
|
gsub(/^[[:space:]]+|[[:space:]]+$/, "", value)
|
||||||
|
env_values[key] = value
|
||||||
|
}
|
||||||
|
next
|
||||||
|
}
|
||||||
|
{
|
||||||
|
if (!/^[[:space:]]*#/ && !/^[[:space:]]*$/ && /=/) {
|
||||||
|
gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1)
|
||||||
|
key = $1
|
||||||
|
example_value = substr($0, index($0,"=")+1)
|
||||||
|
gsub(/^[[:space:]]+|[[:space:]]+$/, "", example_value)
|
||||||
|
|
||||||
|
if (key in env_values && env_values[key] != example_value) {
|
||||||
|
print key, env_values[key], example_value > "'$temp_diff'"
|
||||||
|
diff_count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
END { print diff_count }
|
||||||
|
' .env .env.example)
|
||||||
|
|
||||||
|
if [[ $diff_count -gt 0 ]]; then
|
||||||
|
log_success "Detected differences in $diff_count environment variables"
|
||||||
|
# Show detailed differences
|
||||||
|
show_differences_detail
|
||||||
|
else
|
||||||
|
log_info "No differences detected"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse environment variable line
|
||||||
|
# Extracts key-value pairs from .env file format lines
|
||||||
|
# Arguments:
|
||||||
|
# $1 - Line to parse
|
||||||
|
# Returns:
|
||||||
|
# 0 - Success, outputs "key|value" format
|
||||||
|
# 1 - Skip (empty line, comment, or invalid format)
|
||||||
|
parse_env_line() {
|
||||||
|
local line="$1"
|
||||||
|
local key=""
|
||||||
|
local value=""
|
||||||
|
|
||||||
|
# Skip empty lines or comment lines
|
||||||
|
[[ -z "$line" || "$line" =~ ^[[:space:]]*# ]] && return 1
|
||||||
|
|
||||||
|
# Split by =
|
||||||
|
if [[ "$line" =~ ^([^=]+)=(.*)$ ]]; then
|
||||||
|
key="${BASH_REMATCH[1]}"
|
||||||
|
value="${BASH_REMATCH[2]}"
|
||||||
|
|
||||||
|
# Remove leading and trailing whitespace
|
||||||
|
key=$(echo "$key" | sed 's/^[[:space:]]*//; s/[[:space:]]*$//')
|
||||||
|
value=$(echo "$value" | sed 's/^[[:space:]]*//; s/[[:space:]]*$//')
|
||||||
|
|
||||||
|
if [[ -n "$key" ]]; then
|
||||||
|
echo "$key|$value"
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Show detailed differences
|
||||||
|
show_differences_detail() {
|
||||||
|
log_info ""
|
||||||
|
log_info "=== Environment Variable Differences ==="
|
||||||
|
|
||||||
|
# Read differences from the already created diff file
|
||||||
|
if [[ ! -s "$DIFF_FILE" ]]; then
|
||||||
|
log_info "No differences to display"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Display differences
|
||||||
|
local count=1
|
||||||
|
while IFS=$'\x01' read -r key env_value example_value; do
|
||||||
|
echo ""
|
||||||
|
echo -e "${YELLOW}[$count] $key${NC}"
|
||||||
|
echo -e " ${GREEN}.env (current)${NC} : ${env_value}"
|
||||||
|
echo -e " ${BLUE}.env.example (recommended)${NC}: ${example_value}"
|
||||||
|
|
||||||
|
# Analyze value changes
|
||||||
|
analyze_value_change "$env_value" "$example_value"
|
||||||
|
((count++))
|
||||||
|
done < "$DIFF_FILE"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
log_info "=== Difference Analysis Complete ==="
|
||||||
|
log_info "Note: Consider changing to the recommended values above."
|
||||||
|
log_info "Current implementation preserves .env values."
|
||||||
|
echo ""
|
||||||
|
}
|
||||||
|
|
||||||
|
# Analyze value changes
|
||||||
|
analyze_value_change() {
|
||||||
|
local current_value="$1"
|
||||||
|
local recommended_value="$2"
|
||||||
|
|
||||||
|
# Analyze value characteristics
|
||||||
|
local analysis=""
|
||||||
|
|
||||||
|
# Empty value check
|
||||||
|
if [[ -z "$current_value" && -n "$recommended_value" ]]; then
|
||||||
|
analysis=" ${RED}→ Setting from empty to recommended value${NC}"
|
||||||
|
elif [[ -n "$current_value" && -z "$recommended_value" ]]; then
|
||||||
|
analysis=" ${RED}→ Recommended value changed to empty${NC}"
|
||||||
|
# Numeric check - using arithmetic evaluation for robust comparison
|
||||||
|
elif [[ "$current_value" =~ ^[0-9]+$ && "$recommended_value" =~ ^[0-9]+$ ]]; then
|
||||||
|
# Use arithmetic evaluation to handle leading zeros correctly
|
||||||
|
if (( 10#$current_value < 10#$recommended_value )); then
|
||||||
|
analysis=" ${BLUE}→ Numeric increase (${current_value} < ${recommended_value})${NC}"
|
||||||
|
elif (( 10#$current_value > 10#$recommended_value )); then
|
||||||
|
analysis=" ${YELLOW}→ Numeric decrease (${current_value} > ${recommended_value})${NC}"
|
||||||
|
fi
|
||||||
|
# Boolean check
|
||||||
|
elif [[ "$current_value" =~ ^(true|false)$ && "$recommended_value" =~ ^(true|false)$ ]]; then
|
||||||
|
if [[ "$current_value" != "$recommended_value" ]]; then
|
||||||
|
analysis=" ${BLUE}→ Boolean value change (${current_value} → ${recommended_value})${NC}"
|
||||||
|
fi
|
||||||
|
# URL/endpoint check
|
||||||
|
elif [[ "$current_value" =~ ^https?:// || "$recommended_value" =~ ^https?:// ]]; then
|
||||||
|
analysis=" ${BLUE}→ URL/endpoint change${NC}"
|
||||||
|
# File path check
|
||||||
|
elif [[ "$current_value" =~ ^/ || "$recommended_value" =~ ^/ ]]; then
|
||||||
|
analysis=" ${BLUE}→ File path change${NC}"
|
||||||
|
else
|
||||||
|
# Length comparison
|
||||||
|
local current_len=${#current_value}
|
||||||
|
local recommended_len=${#recommended_value}
|
||||||
|
if [[ $current_len -ne $recommended_len ]]; then
|
||||||
|
analysis=" ${YELLOW}→ String length change (${current_len} → ${recommended_len} characters)${NC}"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -n "$analysis" ]]; then
|
||||||
|
echo -e "$analysis"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Synchronize .env file with .env.example while preserving custom values
|
||||||
|
# Creates a new .env file based on .env.example structure, preserving existing custom values
|
||||||
|
# Global variables used: DIFF_FILE, TEMP_DIR
|
||||||
|
sync_env_file() {
|
||||||
|
log_info "Starting partial synchronization of .env file..."
|
||||||
|
|
||||||
|
local new_env_file=".env.new"
|
||||||
|
local preserved_count=0
|
||||||
|
local updated_count=0
|
||||||
|
|
||||||
|
# Pre-process diff file for efficient lookup
|
||||||
|
local lookup_file=""
|
||||||
|
if [[ -f "$DIFF_FILE" && -s "$DIFF_FILE" ]]; then
|
||||||
|
lookup_file="${DIFF_FILE}.lookup"
|
||||||
|
# Create sorted lookup file for fast search
|
||||||
|
sort "$DIFF_FILE" > "$lookup_file"
|
||||||
|
log_info "Created lookup file for $(wc -l < "$DIFF_FILE") preserved values"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Use AWK for efficient processing (much faster than bash loop for large files)
|
||||||
|
log_info "Processing $(wc -l < .env.example) lines with AWK..."
|
||||||
|
|
||||||
|
local preserved_keys_file="${TEMP_DIR}/preserved_keys"
|
||||||
|
local awk_preserved_count_file="${TEMP_DIR}/awk_preserved_count"
|
||||||
|
local awk_updated_count_file="${TEMP_DIR}/awk_updated_count"
|
||||||
|
|
||||||
|
awk -F'=' -v lookup_file="$lookup_file" -v preserved_file="$preserved_keys_file" \
|
||||||
|
-v preserved_count_file="$awk_preserved_count_file" -v updated_count_file="$awk_updated_count_file" '
|
||||||
|
BEGIN {
|
||||||
|
preserved_count = 0
|
||||||
|
updated_count = 0
|
||||||
|
|
||||||
|
# Load preserved values if lookup file exists
|
||||||
|
if (lookup_file != "") {
|
||||||
|
while ((getline line < lookup_file) > 0) {
|
||||||
|
split(line, parts, "\x01")
|
||||||
|
key = parts[1]
|
||||||
|
value = parts[2]
|
||||||
|
preserved_values[key] = value
|
||||||
|
}
|
||||||
|
close(lookup_file)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process each line
|
||||||
|
{
|
||||||
|
# Check if this is an environment variable line
|
||||||
|
if (/^[[:space:]]*[A-Za-z_][A-Za-z0-9_]*[[:space:]]*=/) {
|
||||||
|
# Extract key
|
||||||
|
key = $1
|
||||||
|
gsub(/^[[:space:]]+|[[:space:]]+$/, "", key)
|
||||||
|
|
||||||
|
# Check if key should be preserved
|
||||||
|
if (key in preserved_values) {
|
||||||
|
print key "=" preserved_values[key]
|
||||||
|
print key > preserved_file
|
||||||
|
preserved_count++
|
||||||
|
} else {
|
||||||
|
print $0
|
||||||
|
updated_count++
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
# Not an env var line, preserve as-is
|
||||||
|
print $0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
END {
|
||||||
|
print preserved_count > preserved_count_file
|
||||||
|
print updated_count > updated_count_file
|
||||||
|
}
|
||||||
|
' .env.example > "$new_env_file"
|
||||||
|
|
||||||
|
# Read counters and preserved keys
|
||||||
|
if [[ -f "$awk_preserved_count_file" ]]; then
|
||||||
|
preserved_count=$(cat "$awk_preserved_count_file")
|
||||||
|
fi
|
||||||
|
if [[ -f "$awk_updated_count_file" ]]; then
|
||||||
|
updated_count=$(cat "$awk_updated_count_file")
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Show what was preserved
|
||||||
|
if [[ -f "$preserved_keys_file" ]]; then
|
||||||
|
while read -r key; do
|
||||||
|
[[ -n "$key" ]] && log_info " Preserved: $key (.env value)"
|
||||||
|
done < "$preserved_keys_file"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Clean up lookup file
|
||||||
|
[[ -n "$lookup_file" ]] && rm -f "$lookup_file"
|
||||||
|
|
||||||
|
# Replace the original .env file
|
||||||
|
if mv "$new_env_file" ".env"; then
|
||||||
|
log_success "Successfully created new .env file"
|
||||||
|
else
|
||||||
|
log_error "Failed to replace .env file"
|
||||||
|
rm -f "$new_env_file"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Clean up difference file and temporary directory
|
||||||
|
if [[ -n "${TEMP_DIR:-}" ]]; then
|
||||||
|
rm -rf "${TEMP_DIR}"
|
||||||
|
unset TEMP_DIR
|
||||||
|
fi
|
||||||
|
if [[ -n "${DIFF_FILE:-}" ]]; then
|
||||||
|
unset DIFF_FILE
|
||||||
|
fi
|
||||||
|
|
||||||
|
log_success "Partial synchronization of .env file completed"
|
||||||
|
log_info " Preserved .env values: $preserved_count"
|
||||||
|
log_info " Updated to .env.example values: $updated_count"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Detect removed environment variables
|
||||||
|
detect_removed_variables() {
|
||||||
|
log_info "Detecting removed environment variables..."
|
||||||
|
|
||||||
|
if [[ ! -f ".env" ]]; then
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Use temporary files for efficient lookup
|
||||||
|
local temp_dir="${TEMP_DIR:-$(mktemp -d)}"
|
||||||
|
local temp_example_keys="$temp_dir/example_keys"
|
||||||
|
local temp_current_keys="$temp_dir/current_keys"
|
||||||
|
local cleanup_temp_dir=""
|
||||||
|
|
||||||
|
# Set flag if we created a new temp directory
|
||||||
|
if [[ -z "${TEMP_DIR:-}" ]]; then
|
||||||
|
cleanup_temp_dir="$temp_dir"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Get keys from .env.example and .env, sorted for comm
|
||||||
|
awk -F= '!/^[[:space:]]*#/ && /=/ {gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1); print $1}' .env.example | sort > "$temp_example_keys"
|
||||||
|
awk -F= '!/^[[:space:]]*#/ && /=/ {gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1); print $1}' .env | sort > "$temp_current_keys"
|
||||||
|
|
||||||
|
# Get keys from existing .env and check for removals
|
||||||
|
local removed_vars=()
|
||||||
|
while IFS= read -r var; do
|
||||||
|
removed_vars+=("$var")
|
||||||
|
done < <(comm -13 "$temp_example_keys" "$temp_current_keys")
|
||||||
|
|
||||||
|
# Clean up temporary files if we created a new temp directory
|
||||||
|
if [[ -n "$cleanup_temp_dir" ]]; then
|
||||||
|
rm -rf "$cleanup_temp_dir"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ ${#removed_vars[@]} -gt 0 ]]; then
|
||||||
|
log_warning "The following environment variables have been removed from .env.example:"
|
||||||
|
for var in "${removed_vars[@]}"; do
|
||||||
|
log_warning " - $var"
|
||||||
|
done
|
||||||
|
log_warning "Consider manually removing these variables from .env"
|
||||||
|
else
|
||||||
|
log_success "No removed environment variables found"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Show statistics
|
||||||
|
show_statistics() {
|
||||||
|
log_info "Synchronization statistics:"
|
||||||
|
|
||||||
|
local total_example=$(grep -c "^[^#]*=" .env.example 2>/dev/null || echo "0")
|
||||||
|
local total_env=$(grep -c "^[^#]*=" .env 2>/dev/null || echo "0")
|
||||||
|
|
||||||
|
log_info " .env.example environment variables: $total_example"
|
||||||
|
log_info " .env environment variables: $total_env"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Main execution function
|
||||||
|
# Orchestrates the complete synchronization process in the correct order
|
||||||
|
main() {
|
||||||
|
log_info "=== Dify Environment Variables Synchronization Script ==="
|
||||||
|
log_info "Execution started: $(date)"
|
||||||
|
|
||||||
|
# Check prerequisites
|
||||||
|
check_files
|
||||||
|
|
||||||
|
# Create backup
|
||||||
|
create_backup
|
||||||
|
|
||||||
|
# Detect differences
|
||||||
|
detect_differences
|
||||||
|
|
||||||
|
# Detect removed variables (before sync)
|
||||||
|
detect_removed_variables
|
||||||
|
|
||||||
|
# Synchronize environment file
|
||||||
|
sync_env_file
|
||||||
|
|
||||||
|
# Show statistics
|
||||||
|
show_statistics
|
||||||
|
|
||||||
|
log_success "=== Synchronization process completed successfully ==="
|
||||||
|
log_info "Execution finished: $(date)"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Execute main function only when script is run directly
|
||||||
|
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
|
||||||
|
main "$@"
|
||||||
|
fi
|
||||||
|
|
@ -270,7 +270,7 @@ services:
|
||||||
|
|
||||||
# plugin daemon
|
# plugin daemon
|
||||||
plugin_daemon:
|
plugin_daemon:
|
||||||
image: langgenius/dify-plugin-daemon:0.5.1-local
|
image: langgenius/dify-plugin-daemon:0.5.2-local
|
||||||
restart: always
|
restart: always
|
||||||
environment:
|
environment:
|
||||||
# Use the shared environment variables.
|
# Use the shared environment variables.
|
||||||
|
|
|
||||||
|
|
@ -123,7 +123,7 @@ services:
|
||||||
|
|
||||||
# plugin daemon
|
# plugin daemon
|
||||||
plugin_daemon:
|
plugin_daemon:
|
||||||
image: langgenius/dify-plugin-daemon:0.5.1-local
|
image: langgenius/dify-plugin-daemon:0.5.2-local
|
||||||
restart: always
|
restart: always
|
||||||
env_file:
|
env_file:
|
||||||
- ./middleware.env
|
- ./middleware.env
|
||||||
|
|
|
||||||
|
|
@ -134,6 +134,7 @@ x-shared-env: &shared-api-worker-env
|
||||||
ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1}
|
ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1}
|
||||||
ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4}
|
ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4}
|
||||||
ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path}
|
ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path}
|
||||||
|
ALIYUN_CLOUDBOX_ID: ${ALIYUN_CLOUDBOX_ID:-your-cloudbox-id}
|
||||||
TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name}
|
TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name}
|
||||||
TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key}
|
TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key}
|
||||||
TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id}
|
TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id}
|
||||||
|
|
@ -148,6 +149,7 @@ x-shared-env: &shared-api-worker-env
|
||||||
HUAWEI_OBS_SECRET_KEY: ${HUAWEI_OBS_SECRET_KEY:-your-secret-key}
|
HUAWEI_OBS_SECRET_KEY: ${HUAWEI_OBS_SECRET_KEY:-your-secret-key}
|
||||||
HUAWEI_OBS_ACCESS_KEY: ${HUAWEI_OBS_ACCESS_KEY:-your-access-key}
|
HUAWEI_OBS_ACCESS_KEY: ${HUAWEI_OBS_ACCESS_KEY:-your-access-key}
|
||||||
HUAWEI_OBS_SERVER: ${HUAWEI_OBS_SERVER:-your-server-url}
|
HUAWEI_OBS_SERVER: ${HUAWEI_OBS_SERVER:-your-server-url}
|
||||||
|
HUAWEI_OBS_PATH_STYLE: ${HUAWEI_OBS_PATH_STYLE:-false}
|
||||||
VOLCENGINE_TOS_BUCKET_NAME: ${VOLCENGINE_TOS_BUCKET_NAME:-your-bucket-name}
|
VOLCENGINE_TOS_BUCKET_NAME: ${VOLCENGINE_TOS_BUCKET_NAME:-your-bucket-name}
|
||||||
VOLCENGINE_TOS_SECRET_KEY: ${VOLCENGINE_TOS_SECRET_KEY:-your-secret-key}
|
VOLCENGINE_TOS_SECRET_KEY: ${VOLCENGINE_TOS_SECRET_KEY:-your-secret-key}
|
||||||
VOLCENGINE_TOS_ACCESS_KEY: ${VOLCENGINE_TOS_ACCESS_KEY:-your-access-key}
|
VOLCENGINE_TOS_ACCESS_KEY: ${VOLCENGINE_TOS_ACCESS_KEY:-your-access-key}
|
||||||
|
|
@ -939,7 +941,7 @@ services:
|
||||||
|
|
||||||
# plugin daemon
|
# plugin daemon
|
||||||
plugin_daemon:
|
plugin_daemon:
|
||||||
image: langgenius/dify-plugin-daemon:0.5.1-local
|
image: langgenius/dify-plugin-daemon:0.5.2-local
|
||||||
restart: always
|
restart: always
|
||||||
environment:
|
environment:
|
||||||
# Use the shared environment variables.
|
# Use the shared environment variables.
|
||||||
|
|
|
||||||
|
|
@ -1,48 +1,40 @@
|
||||||
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
|
# Dependencies
|
||||||
|
node_modules/
|
||||||
|
|
||||||
# dependencies
|
# Build output
|
||||||
/node_modules
|
dist/
|
||||||
/.pnp
|
|
||||||
.pnp.js
|
|
||||||
|
|
||||||
# testing
|
# Testing
|
||||||
/coverage
|
coverage/
|
||||||
|
|
||||||
# next.js
|
# IDE
|
||||||
/.next/
|
.idea/
|
||||||
/out/
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
# production
|
# OS
|
||||||
/build
|
|
||||||
|
|
||||||
# misc
|
|
||||||
.DS_Store
|
.DS_Store
|
||||||
*.pem
|
Thumbs.db
|
||||||
|
|
||||||
# debug
|
# Debug logs
|
||||||
npm-debug.log*
|
npm-debug.log*
|
||||||
yarn-debug.log*
|
yarn-debug.log*
|
||||||
yarn-error.log*
|
yarn-error.log*
|
||||||
.pnpm-debug.log*
|
pnpm-debug.log*
|
||||||
|
|
||||||
# local env files
|
# Environment
|
||||||
.env*.local
|
.env
|
||||||
|
.env.local
|
||||||
|
.env.*.local
|
||||||
|
|
||||||
# vercel
|
# TypeScript
|
||||||
.vercel
|
|
||||||
|
|
||||||
# typescript
|
|
||||||
*.tsbuildinfo
|
*.tsbuildinfo
|
||||||
next-env.d.ts
|
|
||||||
|
|
||||||
# npm
|
# Lock files (use pnpm-lock.yaml in CI if needed)
|
||||||
package-lock.json
|
package-lock.json
|
||||||
|
yarn.lock
|
||||||
|
|
||||||
# yarn
|
# Misc
|
||||||
.pnp.cjs
|
*.pem
|
||||||
.pnp.loader.mjs
|
*.tgz
|
||||||
.yarn/
|
|
||||||
.yarnrc.yml
|
|
||||||
|
|
||||||
# pmpm
|
|
||||||
pnpm-lock.yaml
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 LangGenius
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue