mirror of https://github.com/langgenius/dify.git
mr main
This commit is contained in:
commit
130163ca65
|
|
@ -49,10 +49,10 @@ pnpm test
|
|||
pnpm test:watch
|
||||
|
||||
# Run specific file
|
||||
pnpm test -- path/to/file.spec.tsx
|
||||
pnpm test path/to/file.spec.tsx
|
||||
|
||||
# Generate coverage report
|
||||
pnpm test -- --coverage
|
||||
pnpm test:coverage
|
||||
|
||||
# Analyze component complexity
|
||||
pnpm analyze-component <path>
|
||||
|
|
@ -155,7 +155,7 @@ describe('ComponentName', () => {
|
|||
For each file:
|
||||
┌────────────────────────────────────────┐
|
||||
│ 1. Write test │
|
||||
│ 2. Run: pnpm test -- <file>.spec.tsx │
|
||||
│ 2. Run: pnpm test <file>.spec.tsx │
|
||||
│ 3. PASS? → Mark complete, next file │
|
||||
│ FAIL? → Fix first, then continue │
|
||||
└────────────────────────────────────────┘
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ describe('ComponentName', () => {
|
|||
})
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Async Operations (if component fetches data - useSWR, useQuery, fetch)
|
||||
// Async Operations (if component fetches data - useQuery, fetch)
|
||||
// --------------------------------------------------------------------------
|
||||
// WHY: Async operations have 3 states users experience: loading, success, error
|
||||
describe('Async Operations', () => {
|
||||
|
|
|
|||
|
|
@ -114,15 +114,15 @@ For the current file being tested:
|
|||
|
||||
**Run these checks after EACH test file, not just at the end:**
|
||||
|
||||
- [ ] Run `pnpm test -- path/to/file.spec.tsx` - **MUST PASS before next file**
|
||||
- [ ] Run `pnpm test path/to/file.spec.tsx` - **MUST PASS before next file**
|
||||
- [ ] Fix any failures immediately
|
||||
- [ ] Mark file as complete in todo list
|
||||
- [ ] Only then proceed to next file
|
||||
|
||||
### After All Files Complete
|
||||
|
||||
- [ ] Run full directory test: `pnpm test -- path/to/directory/`
|
||||
- [ ] Check coverage report: `pnpm test -- --coverage`
|
||||
- [ ] Run full directory test: `pnpm test path/to/directory/`
|
||||
- [ ] Check coverage report: `pnpm test:coverage`
|
||||
- [ ] Run `pnpm lint:fix` on all test files
|
||||
- [ ] Run `pnpm type-check:tsgo`
|
||||
|
||||
|
|
@ -186,16 +186,16 @@ Always test these scenarios:
|
|||
|
||||
```bash
|
||||
# Run specific test
|
||||
pnpm test -- path/to/file.spec.tsx
|
||||
pnpm test path/to/file.spec.tsx
|
||||
|
||||
# Run with coverage
|
||||
pnpm test -- --coverage path/to/file.spec.tsx
|
||||
pnpm test:coverage path/to/file.spec.tsx
|
||||
|
||||
# Watch mode
|
||||
pnpm test:watch -- path/to/file.spec.tsx
|
||||
pnpm test:watch path/to/file.spec.tsx
|
||||
|
||||
# Update snapshots (use sparingly)
|
||||
pnpm test -- -u path/to/file.spec.tsx
|
||||
pnpm test -u path/to/file.spec.tsx
|
||||
|
||||
# Analyze component
|
||||
pnpm analyze-component path/to/component.tsx
|
||||
|
|
|
|||
|
|
@ -242,32 +242,9 @@ describe('Component with Context', () => {
|
|||
})
|
||||
```
|
||||
|
||||
### 7. SWR / React Query
|
||||
### 7. React Query
|
||||
|
||||
```typescript
|
||||
// SWR
|
||||
vi.mock('swr', () => ({
|
||||
__esModule: true,
|
||||
default: vi.fn(),
|
||||
}))
|
||||
|
||||
import useSWR from 'swr'
|
||||
const mockedUseSWR = vi.mocked(useSWR)
|
||||
|
||||
describe('Component with SWR', () => {
|
||||
it('should show loading state', () => {
|
||||
mockedUseSWR.mockReturnValue({
|
||||
data: undefined,
|
||||
error: undefined,
|
||||
isLoading: true,
|
||||
})
|
||||
|
||||
render(<Component />)
|
||||
expect(screen.getByText(/loading/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// React Query
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
|
||||
const createTestQueryClient = () => new QueryClient({
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ When testing a **single component, hook, or utility**:
|
|||
2. Run `pnpm analyze-component <path>` (if available)
|
||||
3. Check complexity score and features detected
|
||||
4. Write the test file
|
||||
5. Run test: `pnpm test -- <file>.spec.tsx`
|
||||
5. Run test: `pnpm test <file>.spec.tsx`
|
||||
6. Fix any failures
|
||||
7. Verify coverage meets goals (100% function, >95% branch)
|
||||
```
|
||||
|
|
@ -80,7 +80,7 @@ Process files in this recommended order:
|
|||
```
|
||||
┌─────────────────────────────────────────────┐
|
||||
│ 1. Write test file │
|
||||
│ 2. Run: pnpm test -- <file>.spec.tsx │
|
||||
│ 2. Run: pnpm test <file>.spec.tsx │
|
||||
│ 3. If FAIL → Fix immediately, re-run │
|
||||
│ 4. If PASS → Mark complete in todo list │
|
||||
│ 5. ONLY THEN proceed to next file │
|
||||
|
|
@ -95,10 +95,10 @@ After all individual tests pass:
|
|||
|
||||
```bash
|
||||
# Run all tests in the directory together
|
||||
pnpm test -- path/to/directory/
|
||||
pnpm test path/to/directory/
|
||||
|
||||
# Check coverage
|
||||
pnpm test -- --coverage path/to/directory/
|
||||
pnpm test:coverage path/to/directory/
|
||||
```
|
||||
|
||||
## Component Complexity Guidelines
|
||||
|
|
@ -201,9 +201,9 @@ Run pnpm test ← Multiple failures, hard to debug
|
|||
```
|
||||
# GOOD: Incremental with verification
|
||||
Write component-a.spec.tsx
|
||||
Run pnpm test -- component-a.spec.tsx ✅
|
||||
Run pnpm test component-a.spec.tsx ✅
|
||||
Write component-b.spec.tsx
|
||||
Run pnpm test -- component-b.spec.tsx ✅
|
||||
Run pnpm test component-b.spec.tsx ✅
|
||||
...continue...
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -68,25 +68,4 @@ jobs:
|
|||
run: |
|
||||
uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Web dependencies
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: oxlint
|
||||
working-directory: ./web
|
||||
run: pnpm exec oxlint --config .oxlintrc.json --fix .
|
||||
|
||||
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
name: Check i18n Files and Create PR
|
||||
name: Translate i18n Files Based on English
|
||||
|
||||
on:
|
||||
push:
|
||||
|
|
@ -67,25 +67,19 @@ jobs:
|
|||
working-directory: ./web
|
||||
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
|
||||
|
||||
- name: Generate i18n type definitions
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run gen:i18n-types
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: 'chore(i18n): update translations based on en-US changes'
|
||||
title: 'chore(i18n): translate i18n files and update type definitions'
|
||||
title: 'chore(i18n): translate i18n files based on en-US changes'
|
||||
body: |
|
||||
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
|
||||
This PR was automatically created to update i18n translation files based on changes in en-US locale.
|
||||
|
||||
**Triggered by:** ${{ github.sha }}
|
||||
|
||||
**Changes included:**
|
||||
- Updated translation files for all locales
|
||||
- Regenerated TypeScript type definitions for type safety
|
||||
branch: chore/automated-i18n-updates-${{ github.sha }}
|
||||
delete-branch: true
|
||||
|
|
|
|||
|
|
@ -38,11 +38,8 @@ jobs:
|
|||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Check i18n types synchronization
|
||||
run: pnpm run check:i18n-types
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm test --coverage
|
||||
run: pnpm test:coverage
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
|
|
|
|||
|
|
@ -139,7 +139,6 @@ pyrightconfig.json
|
|||
.idea/'
|
||||
|
||||
.DS_Store
|
||||
web/.vscode/settings.json
|
||||
|
||||
# Intellij IDEA Files
|
||||
.idea/*
|
||||
|
|
@ -196,6 +195,7 @@ docker/nginx/ssl/*
|
|||
!docker/nginx/ssl/.gitkeep
|
||||
docker/middleware.env
|
||||
docker/docker-compose.override.yaml
|
||||
docker/env-backup/*
|
||||
|
||||
sdks/python-client/build
|
||||
sdks/python-client/dist
|
||||
|
|
@ -205,7 +205,6 @@ sdks/python-client/dify_client.egg-info
|
|||
!.vscode/launch.json.template
|
||||
!.vscode/README.md
|
||||
api/.vscode
|
||||
web/.vscode
|
||||
# vscode Code History Extension
|
||||
.history
|
||||
|
||||
|
|
@ -220,15 +219,6 @@ plugins.jsonl
|
|||
# mise
|
||||
mise.toml
|
||||
|
||||
# Next.js build output
|
||||
.next/
|
||||
|
||||
# PWA generated files
|
||||
web/public/sw.js
|
||||
web/public/sw.js.map
|
||||
web/public/workbox-*.js
|
||||
web/public/workbox-*.js.map
|
||||
web/public/fallback-*.js
|
||||
|
||||
# AI Assistant
|
||||
.roo/
|
||||
|
|
|
|||
|
|
@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings):
|
|||
default="",
|
||||
)
|
||||
|
||||
HOSTED_POOL_CREDITS: int = Field(
|
||||
description="Pool credits for hosted service",
|
||||
default=200,
|
||||
)
|
||||
|
||||
def get_model_credits(self, model_name: str) -> int:
|
||||
"""
|
||||
Get credit value for a specific model name.
|
||||
|
|
@ -60,19 +65,46 @@ class HostedOpenAiConfig(BaseSettings):
|
|||
|
||||
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-1106,"
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
default="gpt-4,"
|
||||
"gpt-4-turbo-preview,"
|
||||
"gpt-4-turbo-2024-04-09,"
|
||||
"gpt-4-1106-preview,"
|
||||
"gpt-4-0125-preview,"
|
||||
"gpt-4-turbo,"
|
||||
"gpt-4.1,"
|
||||
"gpt-4.1-2025-04-14,"
|
||||
"gpt-4.1-mini,"
|
||||
"gpt-4.1-mini-2025-04-14,"
|
||||
"gpt-4.1-nano,"
|
||||
"gpt-4.1-nano-2025-04-14,"
|
||||
"gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-16k,"
|
||||
"gpt-3.5-turbo-16k-0613,"
|
||||
"gpt-3.5-turbo-1106,"
|
||||
"gpt-3.5-turbo-0613,"
|
||||
"gpt-3.5-turbo-0125,"
|
||||
"text-davinci-003",
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="Quota limit for hosted OpenAI service usage",
|
||||
default=200,
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
"text-davinci-003,"
|
||||
"chatgpt-4o-latest,"
|
||||
"gpt-4o,"
|
||||
"gpt-4o-2024-05-13,"
|
||||
"gpt-4o-2024-08-06,"
|
||||
"gpt-4o-2024-11-20,"
|
||||
"gpt-4o-audio-preview,"
|
||||
"gpt-4o-audio-preview-2025-06-03,"
|
||||
"gpt-4o-mini,"
|
||||
"gpt-4o-mini-2024-07-18,"
|
||||
"o3-mini,"
|
||||
"o3-mini-2025-01-31,"
|
||||
"gpt-5-mini-2025-08-07,"
|
||||
"gpt-5-mini,"
|
||||
"o4-mini,"
|
||||
"o4-mini-2025-04-16,"
|
||||
"gpt-5-chat-latest,"
|
||||
"gpt-5,"
|
||||
"gpt-5-2025-08-07,"
|
||||
"gpt-5-nano,"
|
||||
"gpt-5-nano-2025-08-07",
|
||||
)
|
||||
|
||||
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
|
||||
|
|
@ -87,6 +119,13 @@ class HostedOpenAiConfig(BaseSettings):
|
|||
"gpt-4-turbo-2024-04-09,"
|
||||
"gpt-4-1106-preview,"
|
||||
"gpt-4-0125-preview,"
|
||||
"gpt-4-turbo,"
|
||||
"gpt-4.1,"
|
||||
"gpt-4.1-2025-04-14,"
|
||||
"gpt-4.1-mini,"
|
||||
"gpt-4.1-mini-2025-04-14,"
|
||||
"gpt-4.1-nano,"
|
||||
"gpt-4.1-nano-2025-04-14,"
|
||||
"gpt-3.5-turbo,"
|
||||
"gpt-3.5-turbo-16k,"
|
||||
"gpt-3.5-turbo-16k-0613,"
|
||||
|
|
@ -94,7 +133,150 @@ class HostedOpenAiConfig(BaseSettings):
|
|||
"gpt-3.5-turbo-0613,"
|
||||
"gpt-3.5-turbo-0125,"
|
||||
"gpt-3.5-turbo-instruct,"
|
||||
"text-davinci-003",
|
||||
"text-davinci-003,"
|
||||
"chatgpt-4o-latest,"
|
||||
"gpt-4o,"
|
||||
"gpt-4o-2024-05-13,"
|
||||
"gpt-4o-2024-08-06,"
|
||||
"gpt-4o-2024-11-20,"
|
||||
"gpt-4o-audio-preview,"
|
||||
"gpt-4o-audio-preview-2025-06-03,"
|
||||
"gpt-4o-mini,"
|
||||
"gpt-4o-mini-2024-07-18,"
|
||||
"o3-mini,"
|
||||
"o3-mini-2025-01-31,"
|
||||
"gpt-5-mini-2025-08-07,"
|
||||
"gpt-5-mini,"
|
||||
"o4-mini,"
|
||||
"o4-mini-2025-04-16,"
|
||||
"gpt-5-chat-latest,"
|
||||
"gpt-5,"
|
||||
"gpt-5-2025-08-07,"
|
||||
"gpt-5-nano,"
|
||||
"gpt-5-nano-2025-08-07",
|
||||
)
|
||||
|
||||
|
||||
class HostedGeminiConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching Gemini service
|
||||
"""
|
||||
|
||||
HOSTED_GEMINI_API_KEY: str | None = Field(
|
||||
description="API key for hosted Gemini service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted Gemini API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted Gemini service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted Gemini service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted gemini service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
|
||||
)
|
||||
|
||||
|
||||
class HostedXAIConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching XAI service
|
||||
"""
|
||||
|
||||
HOSTED_XAI_API_KEY: str | None = Field(
|
||||
description="API key for hosted XAI service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted XAI API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted XAI service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_XAI_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_XAI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
HOSTED_XAI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_XAI_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
|
||||
class HostedDeepseekConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching Deepseek service
|
||||
"""
|
||||
|
||||
HOSTED_DEEPSEEK_API_KEY: str | None = Field(
|
||||
description="API key for hosted Deepseek service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_API_BASE: str | None = Field(
|
||||
description="Base URL for hosted Deepseek API",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
|
||||
description="Organization ID for hosted Deepseek service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted Deepseek service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="deepseek-chat,deepseek-reasoner",
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="deepseek-chat,deepseek-reasoner",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -144,16 +326,66 @@ class HostedAnthropicConfig(BaseSettings):
|
|||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
|
||||
description="Quota limit for hosted Anthropic service usage",
|
||||
default=600000,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted Anthropic service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="claude-opus-4-20250514,"
|
||||
"claude-sonnet-4-20250514,"
|
||||
"claude-3-5-haiku-20241022,"
|
||||
"claude-3-opus-20240229,"
|
||||
"claude-3-7-sonnet-20250219,"
|
||||
"claude-3-haiku-20240307",
|
||||
)
|
||||
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="claude-opus-4-20250514,"
|
||||
"claude-sonnet-4-20250514,"
|
||||
"claude-3-5-haiku-20241022,"
|
||||
"claude-3-opus-20240229,"
|
||||
"claude-3-7-sonnet-20250219,"
|
||||
"claude-3-haiku-20240307",
|
||||
)
|
||||
|
||||
|
||||
class HostedTongyiConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for hosted Tongyi service
|
||||
"""
|
||||
|
||||
HOSTED_TONGYI_API_KEY: str | None = Field(
|
||||
description="API key for hosted Tongyi service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT: bool = Field(
|
||||
description="Use international endpoint for hosted Tongyi service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_TONGYI_TRIAL_ENABLED: bool = Field(
|
||||
description="Enable trial access to hosted Anthropic service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_TONGYI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted Anthropic service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_TONGYI_TRIAL_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for trial access",
|
||||
default="",
|
||||
)
|
||||
|
||||
HOSTED_TONGYI_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class HostedMinmaxConfig(BaseSettings):
|
||||
"""
|
||||
|
|
@ -246,9 +478,13 @@ class HostedServiceConfig(
|
|||
HostedOpenAiConfig,
|
||||
HostedSparkConfig,
|
||||
HostedZhipuAIConfig,
|
||||
HostedTongyiConfig,
|
||||
# moderation
|
||||
HostedModerationConfig,
|
||||
# credit config
|
||||
HostedCreditConfig,
|
||||
HostedGeminiConfig,
|
||||
HostedXAIConfig,
|
||||
HostedDeepseekConfig,
|
||||
):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -572,7 +572,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
|
|
|
|||
|
|
@ -1,14 +1,32 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
from ..common.schema import register_schema_models
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
class CodeBasedExtensionQuery(BaseModel):
|
||||
module: str
|
||||
|
||||
|
||||
class APIBasedExtensionPayload(BaseModel):
|
||||
name: str = Field(description="Extension name")
|
||||
api_endpoint: str = Field(description="API endpoint URL")
|
||||
api_key: str = Field(description="API key for authentication")
|
||||
|
||||
|
||||
register_schema_models(console_ns, APIBasedExtensionPayload)
|
||||
|
||||
|
||||
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
|
||||
|
||||
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
|
||||
|
|
@ -18,11 +36,7 @@ api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_m
|
|||
class CodeBasedExtensionAPI(Resource):
|
||||
@console_ns.doc("get_code_based_extension")
|
||||
@console_ns.doc(description="Get code-based extension data by module name")
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"module", type=str, required=True, location="args", help="Extension module name"
|
||||
)
|
||||
)
|
||||
@console_ns.doc(params={"module": "Extension module name"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
|
|
@ -35,10 +49,9 @@ class CodeBasedExtensionAPI(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
||||
return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
|
||||
|
||||
|
||||
@console_ns.route("/api-based-extension")
|
||||
|
|
@ -56,30 +69,21 @@ class APIBasedExtensionAPI(Resource):
|
|||
|
||||
@console_ns.doc("create_api_based_extension")
|
||||
@console_ns.doc(description="Create a new API-based extension")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateAPIBasedExtensionRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="Extension name"),
|
||||
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
|
||||
"api_key": fields.String(required=True, description="API key for authentication"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
|
||||
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_model)
|
||||
def post(self):
|
||||
args = console_ns.payload
|
||||
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=current_tenant_id,
|
||||
name=args["name"],
|
||||
api_endpoint=args["api_endpoint"],
|
||||
api_key=args["api_key"],
|
||||
name=payload.name,
|
||||
api_endpoint=payload.api_endpoint,
|
||||
api_key=payload.api_key,
|
||||
)
|
||||
|
||||
return APIBasedExtensionService.save(extension_data)
|
||||
|
|
@ -104,16 +108,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@console_ns.doc("update_api_based_extension")
|
||||
@console_ns.doc(description="Update API-based extension")
|
||||
@console_ns.doc(params={"id": "Extension ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateAPIBasedExtensionRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="Extension name"),
|
||||
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
|
||||
"api_key": fields.String(required=True, description="API key for authentication"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
|
||||
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -125,13 +120,13 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
args = console_ns.payload
|
||||
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
extension_data_from_db.name = args["name"]
|
||||
extension_data_from_db.api_endpoint = args["api_endpoint"]
|
||||
extension_data_from_db.name = payload.name
|
||||
extension_data_from_db.api_endpoint = payload.api_endpoint
|
||||
|
||||
if args["api_key"] != HIDDEN_VALUE:
|
||||
extension_data_from_db.api_key = args["api_key"]
|
||||
if payload.api_key != HIDDEN_VALUE:
|
||||
extension_data_from_db.api_key = payload.api_key
|
||||
|
||||
return APIBasedExtensionService.save(extension_data_from_db)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import io
|
||||
from typing import Literal
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource
|
||||
|
|
@ -141,6 +142,15 @@ class ParserDynamicOptions(BaseModel):
|
|||
provider_type: Literal["tool", "trigger"]
|
||||
|
||||
|
||||
class ParserDynamicOptionsWithCredentials(BaseModel):
|
||||
plugin_id: str
|
||||
provider: str
|
||||
action: str
|
||||
parameter: str
|
||||
credential_id: str
|
||||
credentials: Mapping[str, Any]
|
||||
|
||||
|
||||
class PluginPermissionSettingsPayload(BaseModel):
|
||||
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
|
||||
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
|
||||
|
|
@ -183,6 +193,7 @@ reg(ParserGithubUpgrade)
|
|||
reg(ParserUninstall)
|
||||
reg(ParserPermissionChange)
|
||||
reg(ParserDynamicOptions)
|
||||
reg(ParserDynamicOptionsWithCredentials)
|
||||
reg(ParserPreferencesChange)
|
||||
reg(ParserExcludePlugin)
|
||||
reg(ParserReadme)
|
||||
|
|
@ -657,6 +668,37 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
|||
return jsonable_encoder({"options": options})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options-with-credentials")
|
||||
class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserDynamicOptionsWithCredentials.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
"""Fetch dynamic options using credentials directly (for edit mode)."""
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
user_id = current_user.id
|
||||
|
||||
args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options_with_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=args.plugin_id,
|
||||
provider=args.provider,
|
||||
action=args.action,
|
||||
parameter=args.parameter,
|
||||
credential_id=args.credential_id,
|
||||
credentials=args.credentials,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"options": options})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/change")
|
||||
class PluginChangePreferencesApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
|
|
@ -32,6 +36,32 @@ from ..wraps import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerSubscriptionUpdateRequest(BaseModel):
|
||||
"""Request payload for updating a trigger subscription"""
|
||||
|
||||
name: str | None = Field(default=None, description="The name for the subscription")
|
||||
credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
|
||||
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
|
||||
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
|
||||
|
||||
|
||||
class TriggerSubscriptionVerifyRequest(BaseModel):
|
||||
"""Request payload for verifying subscription credentials."""
|
||||
|
||||
credentials: Mapping[str, Any] = Field(description="The credentials to verify")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
TriggerSubscriptionUpdateRequest.__name__,
|
||||
TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
TriggerSubscriptionVerifyRequest.__name__,
|
||||
TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
|
||||
class TriggerProviderIconApi(Resource):
|
||||
@setup_required
|
||||
|
|
@ -155,16 +185,16 @@ parser_api = (
|
|||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
|
||||
)
|
||||
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
|
||||
@console_ns.expect(parser_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Verify a subscription instance for a trigger provider"""
|
||||
"""Verify and update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
|
|
@ -289,6 +319,83 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
|||
raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/update",
|
||||
)
|
||||
class TriggerSubscriptionUpdateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, subscription_id: str):
|
||||
"""Update a subscription instance"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
|
||||
|
||||
subscription = TriggerProviderService.get_subscription_by_id(
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
if not subscription:
|
||||
raise NotFoundError(f"Subscription {subscription_id} not found")
|
||||
|
||||
provider_id = TriggerProviderID(subscription.provider_id)
|
||||
|
||||
try:
|
||||
# rename only
|
||||
if (
|
||||
args.name is not None
|
||||
and args.credentials is None
|
||||
and args.parameters is None
|
||||
and args.properties is None
|
||||
):
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
name=args.name,
|
||||
)
|
||||
return 200
|
||||
|
||||
# rebuild for create automatically by the provider
|
||||
match subscription.credential_type:
|
||||
case CredentialType.UNAUTHORIZED:
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
name=args.name,
|
||||
properties=args.properties,
|
||||
)
|
||||
return 200
|
||||
case CredentialType.API_KEY | CredentialType.OAUTH2:
|
||||
if args.credentials:
|
||||
new_credentials: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in args.credentials.items()
|
||||
}
|
||||
else:
|
||||
new_credentials = subscription.credentials
|
||||
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
name=args.name,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription_id,
|
||||
credentials=new_credentials,
|
||||
parameters=args.parameters or subscription.parameters,
|
||||
)
|
||||
return 200
|
||||
case _:
|
||||
raise BadRequest("Invalid credential type")
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error updating subscription", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
||||
)
|
||||
|
|
@ -576,3 +683,38 @@ class TriggerOAuthClientManageApi(Resource):
|
|||
except Exception as e:
|
||||
logger.exception("Error removing OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/verify/<path:subscription_id>",
|
||||
)
|
||||
class TriggerSubscriptionVerifyApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_id):
|
||||
"""Verify credentials for an existing subscription (edit mode only)"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
verify_request: TriggerSubscriptionVerifyRequest = TriggerSubscriptionVerifyRequest.model_validate(
|
||||
console_ns.payload
|
||||
)
|
||||
|
||||
try:
|
||||
result = TriggerProviderService.verify_subscription_credentials(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_id=subscription_id,
|
||||
credentials=verify_request.credentials,
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
logger.warning("Credential verification failed", exc_info=e)
|
||||
raise BadRequest(str(e)) from e
|
||||
except Exception as e:
|
||||
logger.exception("Error verifying subscription credentials", exc_info=e)
|
||||
raise BadRequest(str(e)) from e
|
||||
|
|
|
|||
|
|
@ -80,6 +80,9 @@ tenant_fields = {
|
|||
"in_trial": fields.Boolean,
|
||||
"trial_end_reason": fields.String,
|
||||
"custom_config": fields.Raw(attribute="custom_config"),
|
||||
"trial_credits": fields.Integer,
|
||||
"trial_credits_used": fields.Integer,
|
||||
"next_credit_reset_date": fields.Integer,
|
||||
}
|
||||
|
||||
tenants_fields = {
|
||||
|
|
|
|||
|
|
@ -2,10 +2,12 @@ import base64
|
|||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
|
|
@ -18,14 +20,40 @@ from controllers.console.error import EmailSendIpLimitError
|
|||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||
from controllers.web import web_ns
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
class ForgotPasswordSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
|
||||
|
||||
class ForgotPasswordCheckPayload(BaseModel):
|
||||
email: EmailStr
|
||||
code: str
|
||||
token: str = Field(min_length=1)
|
||||
|
||||
|
||||
class ForgotPasswordResetPayload(BaseModel):
|
||||
token: str = Field(min_length=1)
|
||||
new_password: str
|
||||
password_confirm: str
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password")
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@web_ns.expect(web_ns.models[ForgotPasswordSendPayload.__name__])
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
|
|
@ -40,35 +68,31 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if payload.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
|
||||
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
|
||||
token = None
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
|
||||
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password/validity")
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
@web_ns.expect(web_ns.models[ForgotPasswordCheckPayload.__name__])
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
|
|
@ -78,45 +102,40 @@ class ForgotPasswordCheckApi(Resource):
|
|||
responses={200: "Token is valid", 400: "Bad request - invalid token format", 401: "Invalid or expired token"}
|
||||
)
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = payload.email
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
token_data = AccountService.get_reset_password_data(args["token"])
|
||||
token_data = AccountService.get_reset_password_data(payload.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args["email"])
|
||||
if payload.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(payload.email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
AccountService.revoke_reset_password_token(payload.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=args["code"], additional_data={"phase": "reset"}
|
||||
user_email, code=payload.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args["email"])
|
||||
AccountService.reset_forgot_password_error_rate_limit(payload.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password/resets")
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
@web_ns.expect(web_ns.models[ForgotPasswordResetPayload.__name__])
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
|
|
@ -131,20 +150,14 @@ class ForgotPasswordResetApi(Resource):
|
|||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = ForgotPasswordResetPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
# Validate passwords match
|
||||
if args["new_password"] != args["password_confirm"]:
|
||||
if payload.new_password != payload.password_confirm:
|
||||
raise PasswordMismatchError()
|
||||
|
||||
# Validate token and get reset data
|
||||
reset_data = AccountService.get_reset_password_data(args["token"])
|
||||
reset_data = AccountService.get_reset_password_data(payload.token)
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
|
|
@ -152,11 +165,11 @@ class ForgotPasswordResetApi(Resource):
|
|||
raise InvalidTokenError()
|
||||
|
||||
# Revoke token to prevent reuse
|
||||
AccountService.revoke_reset_password_token(args["token"])
|
||||
AccountService.revoke_reset_password_token(payload.token)
|
||||
|
||||
# Generate secure salt and hash password
|
||||
salt = secrets.token_bytes(16)
|
||||
password_hashed = hash_password(args["new_password"], salt)
|
||||
password_hashed = hash_password(payload.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
|
|
@ -170,7 +183,7 @@ class ForgotPasswordResetApi(Resource):
|
|||
|
||||
return {"result": "success"}
|
||||
|
||||
def _update_existing_account(self, account, password_hashed, salt, session):
|
||||
def _update_existing_account(self, account: Account, password_hashed, salt, session):
|
||||
# Update existing account credentials
|
||||
account.password = base64.b64encode(password_hashed).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask_restx import marshal_with, reqparse
|
||||
from flask_restx import marshal_with
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
import services
|
||||
from controllers.common import helpers
|
||||
|
|
@ -10,14 +11,23 @@ from controllers.common.errors import (
|
|||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
|
||||
from services.file_service import FileService
|
||||
|
||||
from ..common.schema import register_schema_models
|
||||
from . import web_ns
|
||||
from .wraps import WebApiResource
|
||||
|
||||
|
||||
class RemoteFileUploadPayload(BaseModel):
|
||||
url: HttpUrl = Field(description="Remote file URL")
|
||||
|
||||
|
||||
register_schema_models(web_ns, RemoteFileUploadPayload)
|
||||
|
||||
|
||||
@web_ns.route("/remote-files/<path:url>")
|
||||
class RemoteFileInfoApi(WebApiResource):
|
||||
|
|
@ -97,10 +107,8 @@ class RemoteFileUploadApi(WebApiResource):
|
|||
FileTooLargeError: File exceeds size limit
|
||||
UnsupportedFileTypeError: File type not supported
|
||||
"""
|
||||
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
url = args["url"]
|
||||
payload = RemoteFileUploadPayload.model_validate(web_ns.payload or {})
|
||||
url = str(payload.url)
|
||||
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
|
|
|
|||
|
|
@ -72,6 +72,22 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
|
|||
)
|
||||
|
||||
|
||||
def _get_user_provided_host_header(headers: dict | None) -> str | None:
|
||||
"""
|
||||
Extract the user-provided Host header from the headers dict.
|
||||
|
||||
This is needed because when using a forward proxy, httpx may override the Host header.
|
||||
We preserve the user's explicit Host header to support virtual hosting and other use cases.
|
||||
"""
|
||||
if not headers:
|
||||
return None
|
||||
# Case-insensitive lookup for Host header
|
||||
for key, value in headers.items():
|
||||
if key.lower() == "host":
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
if "allow_redirects" in kwargs:
|
||||
allow_redirects = kwargs.pop("allow_redirects")
|
||||
|
|
@ -90,10 +106,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
|||
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
|
||||
client = _get_ssrf_client(verify_option)
|
||||
|
||||
# Preserve user-provided Host header
|
||||
# When using a forward proxy, httpx may override the Host header based on the URL.
|
||||
# We extract and preserve any explicitly set Host header to support virtual hosting.
|
||||
headers = kwargs.get("headers", {})
|
||||
user_provided_host = _get_user_provided_host_header(headers)
|
||||
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
# Build the request manually to preserve the Host header
|
||||
# httpx may override the Host header when using a proxy, so we use
|
||||
# the request API to explicitly set headers before sending
|
||||
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
|
||||
if user_provided_host is not None:
|
||||
headers["host"] = user_provided_host
|
||||
kwargs["headers"] = headers
|
||||
response = client.request(method=method, url=url, **kwargs)
|
||||
|
||||
# Check for SSRF protection by Squid proxy
|
||||
if response.status_code in (401, 403):
|
||||
# Check if this is a Squid SSRF rejection
|
||||
|
|
|
|||
|
|
@ -56,6 +56,10 @@ class HostingConfiguration:
|
|||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
|
||||
self.provider_map[f"{DEFAULT_PLUGIN_ID}/tongyi/tongyi"] = self.init_tongyi()
|
||||
|
||||
self.moderation_config = self.init_moderation_config()
|
||||
|
||||
|
|
@ -128,7 +132,7 @@ class HostingConfiguration:
|
|||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
|
||||
hosted_quota_limit = 0
|
||||
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
|
||||
quotas.append(trial_quota)
|
||||
|
|
@ -156,18 +160,49 @@ class HostingConfiguration:
|
|||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_anthropic() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
def init_gemini(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_GEMINI_API_BASE:
|
||||
credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_anthropic(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
|
||||
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
|
||||
paid_quota = PaidHostingQuota()
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
|
|
@ -185,6 +220,94 @@ class HostingConfiguration:
|
|||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_tongyi(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_TONGYI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"dashscope_api_key": dify_config.HOSTED_TONGYI_API_KEY,
|
||||
"use_international_endpoint": dify_config.HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT,
|
||||
}
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_xai(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_XAI_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_XAI_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"api_key": dify_config.HOSTED_XAI_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_XAI_API_BASE:
|
||||
credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
def init_deepseek(self) -> HostingProvider:
|
||||
quota_unit = QuotaUnit.CREDITS
|
||||
quotas: list[HostingQuota] = []
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
|
||||
hosted_quota_limit = 0
|
||||
trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
|
||||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
|
||||
}
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_API_BASE:
|
||||
credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
|
||||
|
||||
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
|
||||
|
||||
return HostingProvider(
|
||||
enabled=False,
|
||||
quota_unit=quota_unit,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def init_minimax() -> HostingProvider:
|
||||
quota_unit = QuotaUnit.TOKENS
|
||||
|
|
|
|||
|
|
@ -396,7 +396,7 @@ class IndexingRunner:
|
|||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"credential_id": data_source_info.get("credential_id"),
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ class SSETransport:
|
|||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.endpoint_url: str | None = None
|
||||
self.event_source: EventSource | None = None
|
||||
|
||||
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
|
||||
"""Validate that the endpoint URL matches the connection origin.
|
||||
|
|
@ -237,6 +238,9 @@ class SSETransport:
|
|||
write_queue: WriteQueue = queue.Queue()
|
||||
status_queue: StatusQueue = queue.Queue()
|
||||
|
||||
# Store event_source for graceful shutdown
|
||||
self.event_source = event_source
|
||||
|
||||
# Start SSE reader thread
|
||||
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
|
||||
|
||||
|
|
@ -296,6 +300,13 @@ def sse_client(
|
|||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Close the SSE connection to unblock the reader thread
|
||||
if transport.event_source is not None:
|
||||
try:
|
||||
transport.event_source.response.close()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ and session management.
|
|||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
|
|
@ -103,6 +104,9 @@ class StreamableHTTPTransport:
|
|||
CONTENT_TYPE: JSON,
|
||||
**self.headers,
|
||||
}
|
||||
self.stop_event = threading.Event()
|
||||
self._active_responses: list[httpx.Response] = []
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Update headers with session ID if available."""
|
||||
|
|
@ -111,6 +115,30 @@ class StreamableHTTPTransport:
|
|||
headers[MCP_SESSION_ID] = self.session_id
|
||||
return headers
|
||||
|
||||
def _register_response(self, response: httpx.Response):
|
||||
"""Register a response for cleanup on shutdown."""
|
||||
with self._lock:
|
||||
self._active_responses.append(response)
|
||||
|
||||
def _unregister_response(self, response: httpx.Response):
|
||||
"""Unregister a response after it's closed."""
|
||||
with self._lock:
|
||||
try:
|
||||
self._active_responses.remove(response)
|
||||
except ValueError as e:
|
||||
logger.debug("Ignoring error during response unregister: %s", e)
|
||||
|
||||
def close_active_responses(self):
|
||||
"""Close all active SSE connections to unblock threads."""
|
||||
with self._lock:
|
||||
responses_to_close = list(self._active_responses)
|
||||
self._active_responses.clear()
|
||||
for response in responses_to_close:
|
||||
try:
|
||||
response.close()
|
||||
except RuntimeError as e:
|
||||
logger.debug("Ignoring error during active response close: %s", e)
|
||||
|
||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialization request."""
|
||||
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
||||
|
|
@ -195,11 +223,21 @@ class StreamableHTTPTransport:
|
|||
event_source.response.raise_for_status()
|
||||
logger.debug("GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, server_to_client_queue)
|
||||
# Register response for cleanup
|
||||
self._register_response(event_source.response)
|
||||
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
if self.stop_event.is_set():
|
||||
logger.debug("GET stream received stop signal")
|
||||
break
|
||||
self._handle_sse_event(sse, server_to_client_queue)
|
||||
finally:
|
||||
self._unregister_response(event_source.response)
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug("GET stream error (non-fatal): %s", exc)
|
||||
if not self.stop_event.is_set():
|
||||
logger.debug("GET stream error (non-fatal): %s", exc)
|
||||
|
||||
def _handle_resumption_request(self, ctx: RequestContext):
|
||||
"""Handle a resumption request using GET with SSE."""
|
||||
|
|
@ -224,15 +262,24 @@ class StreamableHTTPTransport:
|
|||
event_source.response.raise_for_status()
|
||||
logger.debug("Resumption GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
original_request_id,
|
||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
# Register response for cleanup
|
||||
self._register_response(event_source.response)
|
||||
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
if self.stop_event.is_set():
|
||||
logger.debug("Resumption stream received stop signal")
|
||||
break
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
original_request_id,
|
||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
finally:
|
||||
self._unregister_response(event_source.response)
|
||||
|
||||
def _handle_post_request(self, ctx: RequestContext):
|
||||
"""Handle a POST request with response processing."""
|
||||
|
|
@ -295,17 +342,27 @@ class StreamableHTTPTransport:
|
|||
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
|
||||
"""Handle SSE response from the server."""
|
||||
try:
|
||||
# Register response for cleanup
|
||||
self._register_response(response)
|
||||
|
||||
event_source = EventSource(response)
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
if self.stop_event.is_set():
|
||||
logger.debug("SSE response stream received stop signal")
|
||||
break
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
finally:
|
||||
self._unregister_response(response)
|
||||
except Exception as e:
|
||||
ctx.server_to_client_queue.put(e)
|
||||
if not self.stop_event.is_set():
|
||||
ctx.server_to_client_queue.put(e)
|
||||
|
||||
def _handle_unexpected_content_type(
|
||||
self,
|
||||
|
|
@ -345,6 +402,11 @@ class StreamableHTTPTransport:
|
|||
"""
|
||||
while True:
|
||||
try:
|
||||
# Check if we should stop
|
||||
if self.stop_event.is_set():
|
||||
logger.debug("Post writer received stop signal")
|
||||
break
|
||||
|
||||
# Read message from client queue with timeout to check stop_event periodically
|
||||
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if session_message is None:
|
||||
|
|
@ -381,7 +443,8 @@ class StreamableHTTPTransport:
|
|||
except queue.Empty:
|
||||
continue
|
||||
except Exception as exc:
|
||||
server_to_client_queue.put(exc)
|
||||
if not self.stop_event.is_set():
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def terminate_session(self, client: httpx.Client):
|
||||
"""Terminate the session by sending a DELETE request."""
|
||||
|
|
@ -465,6 +528,12 @@ def streamablehttp_client(
|
|||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
# Set stop event to signal all threads to stop
|
||||
transport.stop_event.set()
|
||||
|
||||
# Close all active SSE connections to unblock threads
|
||||
transport.close_active_responses()
|
||||
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ def generate_dotted_order(run_id: str, start_time: Union[str, datetime], parent_
|
|||
generate dotted_order for langsmith
|
||||
"""
|
||||
start_time = datetime.fromisoformat(start_time) if isinstance(start_time, str) else start_time
|
||||
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f")[:-3] + "Z"
|
||||
timestamp = start_time.strftime("%Y%m%dT%H%M%S%f") + "Z"
|
||||
current_segment = f"{timestamp}{run_id}"
|
||||
|
||||
if parent_dotted_order is None:
|
||||
|
|
|
|||
|
|
@ -619,18 +619,18 @@ class ProviderManager:
|
|||
)
|
||||
|
||||
for quota in configuration.quotas:
|
||||
if quota.quota_type == ProviderQuotaType.TRIAL:
|
||||
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
|
||||
# Init trial provider records if not exists
|
||||
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
|
||||
if quota.quota_type not in provider_quota_to_provider_record_dict:
|
||||
try:
|
||||
# FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic
|
||||
new_provider_record = Provider(
|
||||
tenant_id=tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_limit=quota.quota_limit, # type: ignore
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=quota.quota_type,
|
||||
quota_limit=0, # type: ignore
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
)
|
||||
|
|
@ -642,8 +642,8 @@ class ProviderManager:
|
|||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == ModelProviderID(provider_name).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == quota.quota_type,
|
||||
)
|
||||
existed_provider_record = db.session.scalar(stmt)
|
||||
if not existed_provider_record:
|
||||
|
|
@ -913,6 +913,22 @@ class ProviderManager:
|
|||
provider_record
|
||||
)
|
||||
quota_configurations = []
|
||||
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
trail_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.TRIAL.value,
|
||||
)
|
||||
paid_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.PAID.value,
|
||||
)
|
||||
else:
|
||||
trail_pool = None
|
||||
paid_pool = None
|
||||
|
||||
for provider_quota in provider_hosting_configuration.quotas:
|
||||
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
|
||||
if provider_quota.quota_type == ProviderQuotaType.FREE:
|
||||
|
|
@ -933,16 +949,36 @@ class ProviderManager:
|
|||
raise ValueError("quota_used is None")
|
||||
if provider_record.quota_limit is None:
|
||||
raise ValueError("quota_limit is None")
|
||||
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=trail_pool.quota_used,
|
||||
quota_limit=trail_pool.quota_limit,
|
||||
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=paid_pool.quota_used,
|
||||
quota_limit=paid_pool.quota_limit,
|
||||
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
else:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=provider_record.quota_used,
|
||||
quota_limit=provider_record.quota_limit,
|
||||
is_valid=provider_record.quota_limit > provider_record.quota_used
|
||||
or provider_record.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
quota_configurations.append(quota_configuration)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class FirecrawlApp:
|
|||
}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers)
|
||||
response = self._post_request(self._build_url("v2/scrape"), json_data, headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
data = response_data["data"]
|
||||
|
|
@ -42,7 +42,7 @@ class FirecrawlApp:
|
|||
json_data = {"url": url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers)
|
||||
response = self._post_request(self._build_url("v2/crawl"), json_data, headers)
|
||||
if response.status_code == 200:
|
||||
# There's also another two fields in the response: "success" (bool) and "url" (str)
|
||||
job_id = response.json().get("id")
|
||||
|
|
@ -58,7 +58,7 @@ class FirecrawlApp:
|
|||
if params:
|
||||
# Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v2/map", json_data, headers)
|
||||
response = self._post_request(self._build_url("v2/map"), json_data, headers)
|
||||
if response.status_code == 200:
|
||||
return cast(dict[str, Any], response.json())
|
||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||
|
|
@ -69,7 +69,7 @@ class FirecrawlApp:
|
|||
|
||||
def check_crawl_status(self, job_id) -> dict[str, Any]:
|
||||
headers = self._prepare_headers()
|
||||
response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers)
|
||||
response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers)
|
||||
if response.status_code == 200:
|
||||
crawl_status_response = response.json()
|
||||
if crawl_status_response.get("status") == "completed":
|
||||
|
|
@ -120,6 +120,10 @@ class FirecrawlApp:
|
|||
def _prepare_headers(self) -> dict[str, Any]:
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _build_url(self, path: str) -> str:
|
||||
# ensure exactly one slash between base and path, regardless of user-provided base_url
|
||||
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
|
||||
|
||||
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
|
||||
for attempt in range(retries):
|
||||
response = httpx.post(url, headers=headers, json=data)
|
||||
|
|
@ -139,7 +143,11 @@ class FirecrawlApp:
|
|||
return response
|
||||
|
||||
def _handle_error(self, response, action):
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
try:
|
||||
payload = response.json()
|
||||
error_message = payload.get("error") or payload.get("message") or response.text or "Unknown error occurred"
|
||||
except json.JSONDecodeError:
|
||||
error_message = response.text or "Unknown error occurred"
|
||||
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
|
||||
|
||||
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
|
|
@ -160,7 +168,7 @@ class FirecrawlApp:
|
|||
}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v2/search", json_data, headers)
|
||||
response = self._post_request(self._build_url("v2/search"), json_data, headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if not response_data.get("success"):
|
||||
|
|
|
|||
|
|
@ -48,13 +48,21 @@ class NotionExtractor(BaseExtractor):
|
|||
if notion_access_token:
|
||||
self._notion_access_token = notion_access_token
|
||||
else:
|
||||
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
|
||||
if not self._notion_access_token:
|
||||
try:
|
||||
self._notion_access_token = self._get_access_token(tenant_id, self._credential_id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
(
|
||||
"Failed to get Notion access token from datasource credentials: %s, "
|
||||
"falling back to environment variable NOTION_INTEGRATION_TOKEN"
|
||||
),
|
||||
e,
|
||||
)
|
||||
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
|
||||
if integration_token is None:
|
||||
raise ValueError(
|
||||
"Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`."
|
||||
)
|
||||
) from e
|
||||
|
||||
self._notion_access_token = integration_token
|
||||
|
||||
|
|
|
|||
|
|
@ -153,11 +153,11 @@ class ToolInvokeMessage(BaseModel):
|
|||
@classmethod
|
||||
def transform_variable_value(cls, values):
|
||||
"""
|
||||
Only basic types and lists are allowed.
|
||||
Only basic types, lists, and None are allowed.
|
||||
"""
|
||||
value = values.get("variable_value")
|
||||
if not isinstance(value, dict | list | str | int | float | bool):
|
||||
raise ValueError("Only basic types and lists are allowed.")
|
||||
if value is not None and not isinstance(value, dict | list | str | int | float | bool):
|
||||
raise ValueError("Only basic types, lists, and None are allowed.")
|
||||
|
||||
# if stream is true, the value must be a string
|
||||
if values.get("stream"):
|
||||
|
|
|
|||
|
|
@ -67,12 +67,16 @@ def create_trigger_provider_encrypter_for_subscription(
|
|||
|
||||
|
||||
def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
|
||||
cache = TriggerProviderCredentialsCache(
|
||||
TriggerProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
credential_id=subscription_id,
|
||||
)
|
||||
cache.delete()
|
||||
).delete()
|
||||
TriggerProviderPropertiesCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription_id,
|
||||
).delete()
|
||||
|
||||
|
||||
def create_trigger_provider_encrypter_for_properties(
|
||||
|
|
|
|||
|
|
@ -247,6 +247,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
|||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
COMPLETED_REASON = "completed_reason" # completed reason for loop node
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.file.models import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
|
|
@ -136,21 +136,37 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
|
|||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
|
|
@ -96,3 +97,8 @@ class LoopState(BaseLoopState):
|
|||
Get current output.
|
||||
"""
|
||||
return self.current_output
|
||||
|
||||
|
||||
class LoopCompletedReason(StrEnum):
|
||||
LOOP_BREAK = "loop_break"
|
||||
LOOP_COMPLETED = "loop_completed"
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from core.workflow.node_events import (
|
|||
)
|
||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
||||
from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
|
@ -96,6 +96,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||
loop_duration_map: dict[str, float] = {}
|
||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||
loop_usage = LLMUsage.empty_usage()
|
||||
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
|
||||
|
||||
# Start Loop event
|
||||
yield LoopStartedEvent(
|
||||
|
|
@ -118,6 +119,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||
loop_count = 0
|
||||
|
||||
for i in range(loop_count):
|
||||
# Clear stale variables from previous loop iterations to avoid streaming old values
|
||||
self._clear_loop_subgraph_variables(loop_node_ids)
|
||||
graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id)
|
||||
|
||||
loop_start_time = naive_utc_now()
|
||||
|
|
@ -177,7 +180,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency,
|
||||
"completed_reason": "loop_break" if reach_break_condition else "loop_completed",
|
||||
WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: (
|
||||
LoopCompletedReason.LOOP_BREAK
|
||||
if reach_break_condition
|
||||
else LoopCompletedReason.LOOP_COMPLETED.value
|
||||
),
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
|
|
@ -274,6 +281,17 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||
if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata:
|
||||
event.node_run_result.metadata = {**current_metadata, **loop_metadata}
|
||||
|
||||
def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None:
|
||||
"""
|
||||
Remove variables produced by loop sub-graph nodes from previous iterations.
|
||||
|
||||
Keeping stale variables causes a freshly created response coordinator in the
|
||||
next iteration to fall back to outdated values when no stream chunks exist.
|
||||
"""
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
for node_id in loop_node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
|
|
|||
|
|
@ -281,7 +281,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
|||
|
||||
# handle invoke result
|
||||
|
||||
text = invoke_result.message.content or ""
|
||||
text = invoke_result.message.get_text_content()
|
||||
if not isinstance(text, str):
|
||||
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
|
|
@ -134,22 +134,38 @@ def handle(sender: Message, **kwargs):
|
|||
system_configuration=system_configuration,
|
||||
model_name=model_config.model,
|
||||
)
|
||||
|
||||
if used_quota is not None:
|
||||
quota_update = _ProviderUpdateOperation(
|
||||
filters=_ProviderUpdateFilters(
|
||||
if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=ModelProviderID(model_config.provider).provider_name,
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
||||
),
|
||||
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
||||
additional_filters=_ProviderUpdateAdditionalFilters(
|
||||
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
|
||||
),
|
||||
description="quota_deduction_update",
|
||||
)
|
||||
updates_to_perform.append(quota_update)
|
||||
credits_required=used_quota,
|
||||
pool_type="trial",
|
||||
)
|
||||
elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
quota_update = _ProviderUpdateOperation(
|
||||
filters=_ProviderUpdateFilters(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=ModelProviderID(model_config.provider).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
||||
),
|
||||
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
||||
additional_filters=_ProviderUpdateAdditionalFilters(
|
||||
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
|
||||
),
|
||||
description="quota_deduction_update",
|
||||
)
|
||||
updates_to_perform.append(quota_update)
|
||||
|
||||
# Execute all updates
|
||||
start_time = time_module.perf_counter()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
"""add credit pool
|
||||
|
||||
Revision ID: 7df29de0f6be
|
||||
Revises: 03ea244985ce
|
||||
Create Date: 2025-12-25 10:39:15.139304
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '7df29de0f6be'
|
||||
down_revision = '03ea244985ce'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tenant_credit_pools',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
|
||||
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
|
||||
sa.Column('quota_used', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
|
||||
)
|
||||
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
|
||||
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
|
||||
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
|
||||
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
|
||||
batch_op.drop_index('tenant_credit_pool_tenant_id_idx')
|
||||
batch_op.drop_index('tenant_credit_pool_pool_type_idx')
|
||||
|
||||
op.drop_table('tenant_credit_pools')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -60,6 +60,7 @@ from .model import (
|
|||
Site,
|
||||
Tag,
|
||||
TagBinding,
|
||||
TenantCreditPool,
|
||||
TraceAppConfig,
|
||||
UploadFile,
|
||||
)
|
||||
|
|
@ -177,6 +178,7 @@ __all__ = [
|
|||
"Tenant",
|
||||
"TenantAccountJoin",
|
||||
"TenantAccountRole",
|
||||
"TenantCreditPool",
|
||||
"TenantDefaultModel",
|
||||
"TenantPreferredModelProvider",
|
||||
"TenantStatus",
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ from uuid import uuid4
|
|||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||
from flask_login import UserMixin # type: ignore[import-untyped]
|
||||
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -2065,3 +2065,29 @@ class TraceAppConfig(TypeBase):
|
|||
"created_at": str(self.created_at) if self.created_at else None,
|
||||
"updated_at": str(self.updated_at) if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class TenantCreditPool(Base):
|
||||
__tablename__ = "tenant_credit_pools"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
|
||||
sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"),
|
||||
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
|
||||
quota_limit = mapped_column(BigInteger, nullable=False, default=0)
|
||||
quota_used = mapped_column(BigInteger, nullable=False, default=0)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
@property
|
||||
def remaining_credits(self) -> int:
|
||||
return max(0, self.quota_limit - self.quota_used)
|
||||
|
||||
def has_sufficient_credits(self, required_credits: int) -> bool:
|
||||
return self.remaining_credits >= required_credits
|
||||
|
|
|
|||
|
|
@ -999,6 +999,11 @@ class TenantService:
|
|||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
db.session.commit()
|
||||
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.create_default_pool(tenant.id)
|
||||
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -155,6 +155,7 @@ class AppDslService:
|
|||
parsed_url.scheme == "https"
|
||||
and parsed_url.netloc == "github.com"
|
||||
and parsed_url.path.endswith((".yml", ".yaml"))
|
||||
and "/blob/" in parsed_url.path
|
||||
):
|
||||
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
|
||||
yaml_url = yaml_url.replace("/blob/", "/")
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
|
|||
"limit": 1,
|
||||
"scrapeOptions": {"onlyMainContent": True},
|
||||
}
|
||||
response = self._post_request(f"{self.base_url}/v1/crawl", options, headers)
|
||||
response = self._post_request(self._build_url("v1/crawl"), options, headers)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
|
|
@ -35,15 +35,17 @@ class FirecrawlAuth(ApiKeyAuthBase):
|
|||
def _prepare_headers(self):
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _build_url(self, path: str) -> str:
|
||||
# ensure exactly one slash between base and path, regardless of user-provided base_url
|
||||
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return httpx.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in {402, 409, 500}:
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
else:
|
||||
if response.text:
|
||||
error_message = json.loads(response.text).get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
|
||||
try:
|
||||
payload = response.json()
|
||||
except json.JSONDecodeError:
|
||||
payload = {}
|
||||
error_message = payload.get("error") or payload.get("message") or (response.text or "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,67 @@
|
|||
import logging
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.errors.error import QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from models import TenantCreditPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreditPoolService:
|
||||
@classmethod
|
||||
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
|
||||
"""create default credit pool for new tenant"""
|
||||
credit_pool = TenantCreditPool(
|
||||
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
|
||||
)
|
||||
db.session.add(credit_pool)
|
||||
db.session.commit()
|
||||
return credit_pool
|
||||
|
||||
@classmethod
|
||||
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
|
||||
"""get tenant credit pool"""
|
||||
return (
|
||||
db.session.query(TenantCreditPool)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def check_and_deduct_credits(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
credits_required: int,
|
||||
pool_type: str = "trial",
|
||||
):
|
||||
"""check and deduct credits"""
|
||||
|
||||
pool = cls.get_pool(tenant_id, pool_type)
|
||||
if not pool:
|
||||
raise QuotaExceededError("Credit pool not found")
|
||||
|
||||
if pool.remaining_credits < credits_required:
|
||||
raise QuotaExceededError(
|
||||
f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
|
||||
)
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
update_values = {"quota_used": pool.quota_used + credits_required}
|
||||
|
||||
where_conditions = [
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
|
||||
]
|
||||
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
except Exception:
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
|
|
@ -140,6 +140,7 @@ class FeatureModel(BaseModel):
|
|||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
knowledge_pipeline: KnowledgePipeline = KnowledgePipeline()
|
||||
next_credit_reset_date: int = 0
|
||||
|
||||
|
||||
class KnowledgeRateLimitModel(BaseModel):
|
||||
|
|
@ -301,6 +302,9 @@ class FeatureService:
|
|||
if "knowledge_pipeline_publish_enabled" in billing_info:
|
||||
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]
|
||||
|
||||
if "next_credit_reset_date" in billing_info:
|
||||
features.next_credit_reset_date = billing_info["next_credit_reset_date"]
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
|
||||
enterprise_info = EnterpriseService.get_info()
|
||||
|
|
|
|||
|
|
@ -105,3 +105,49 @@ class PluginParameterService:
|
|||
)
|
||||
.options
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_dynamic_select_options_with_credentials(
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
action: str,
|
||||
parameter: str,
|
||||
credential_id: str,
|
||||
credentials: Mapping[str, Any],
|
||||
) -> Sequence[PluginParameterOption]:
|
||||
"""
|
||||
Get dynamic select options using provided credentials directly.
|
||||
Used for edit mode when credentials have been modified but not yet saved.
|
||||
|
||||
Security: credential_id is validated against tenant_id to ensure
|
||||
users can only access their own credentials.
|
||||
"""
|
||||
from constants import HIDDEN_VALUE
|
||||
|
||||
# Get original subscription to replace hidden values (with tenant_id check for security)
|
||||
original_subscription = TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
|
||||
if not original_subscription:
|
||||
raise ValueError(f"Subscription {credential_id} not found")
|
||||
|
||||
# Replace [__HIDDEN__] with original values
|
||||
resolved_credentials: dict[str, Any] = {
|
||||
key: (original_subscription.credentials.get(key) if value == HIDDEN_VALUE else value)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
|
||||
return (
|
||||
DynamicSelectClient()
|
||||
.fetch_dynamic_select_options(
|
||||
tenant_id,
|
||||
user_id,
|
||||
plugin_id,
|
||||
provider,
|
||||
action,
|
||||
resolved_credentials,
|
||||
CredentialType.API_KEY.value,
|
||||
parameter,
|
||||
)
|
||||
.options
|
||||
)
|
||||
|
|
|
|||
|
|
@ -94,16 +94,23 @@ class TriggerProviderService:
|
|||
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
for subscription in subscriptions:
|
||||
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
subscription=subscription,
|
||||
)
|
||||
subscription.credentials = dict(
|
||||
encrypter.mask_credentials(dict(encrypter.decrypt(subscription.credentials)))
|
||||
credential_encrypter.mask_credentials(dict(credential_encrypter.decrypt(subscription.credentials)))
|
||||
)
|
||||
subscription.properties = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.properties))))
|
||||
subscription.parameters = dict(encrypter.mask_credentials(dict(encrypter.decrypt(subscription.parameters))))
|
||||
properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
subscription=subscription,
|
||||
)
|
||||
subscription.properties = dict(
|
||||
properties_encrypter.mask_credentials(dict(properties_encrypter.decrypt(subscription.properties)))
|
||||
)
|
||||
subscription.parameters = dict(subscription.parameters)
|
||||
count = workflows_in_use_map.get(subscription.id)
|
||||
subscription.workflows_in_use = count if count is not None else 0
|
||||
|
||||
|
|
@ -209,6 +216,101 @@ class TriggerProviderService:
|
|||
logger.exception("Failed to add trigger provider")
|
||||
raise ValueError(str(e))
|
||||
|
||||
@classmethod
|
||||
def update_trigger_subscription(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
subscription_id: str,
|
||||
name: str | None = None,
|
||||
properties: Mapping[str, Any] | None = None,
|
||||
parameters: Mapping[str, Any] | None = None,
|
||||
credentials: Mapping[str, Any] | None = None,
|
||||
credential_expires_at: int | None = None,
|
||||
expires_at: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update an existing trigger subscription.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param subscription_id: Subscription instance ID
|
||||
:param name: Optional new name for this subscription
|
||||
:param properties: Optional new properties
|
||||
:param parameters: Optional new parameters
|
||||
:param credentials: Optional new credentials
|
||||
:param credential_expires_at: Optional new credential expiration timestamp
|
||||
:param expires_at: Optional new expiration timestamp
|
||||
:return: Success response with updated subscription info
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Use distributed lock to prevent race conditions on the same subscription
|
||||
lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
subscription: TriggerSubscription | None = (
|
||||
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError(f"Trigger subscription {subscription_id} not found")
|
||||
|
||||
provider_id = TriggerProviderID(subscription.provider_id)
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
|
||||
# Check for name uniqueness if name is being updated
|
||||
if name is not None and name != subscription.name:
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Subscription name '{name}' already exists for this provider")
|
||||
subscription.name = name
|
||||
|
||||
# Update properties if provided
|
||||
if properties is not None:
|
||||
properties_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_properties_schema(),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
# Handle hidden values - preserve original encrypted values
|
||||
original_properties = properties_encrypter.decrypt(subscription.properties)
|
||||
new_properties: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else original_properties.get(key, UNKNOWN_VALUE)
|
||||
for key, value in properties.items()
|
||||
}
|
||||
subscription.properties = dict(properties_encrypter.encrypt(new_properties))
|
||||
|
||||
# Update parameters if provided
|
||||
if parameters is not None:
|
||||
subscription.parameters = dict(parameters)
|
||||
|
||||
# Update credentials if provided
|
||||
if credentials is not None:
|
||||
credential_type = CredentialType.of(subscription.credential_type)
|
||||
credential_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credential_schema_config(credential_type),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
subscription.credentials = dict(credential_encrypter.encrypt(dict(credentials)))
|
||||
|
||||
# Update credential expiration timestamp if provided
|
||||
if credential_expires_at is not None:
|
||||
subscription.credential_expires_at = credential_expires_at
|
||||
|
||||
# Update expiration timestamp if provided
|
||||
if expires_at is not None:
|
||||
subscription.expires_at = expires_at
|
||||
|
||||
session.commit()
|
||||
|
||||
# Clear subscription cache
|
||||
delete_cache_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=subscription.provider_id,
|
||||
subscription_id=subscription.id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_subscription_by_id(cls, tenant_id: str, subscription_id: str | None = None) -> TriggerSubscription | None:
|
||||
"""
|
||||
|
|
@ -257,17 +359,18 @@ class TriggerProviderService:
|
|||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||
|
||||
credential_type: CredentialType = CredentialType.of(subscription.credential_type)
|
||||
provider_id = TriggerProviderID(subscription.provider_id)
|
||||
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
||||
tenant_id=tenant_id, provider_id=provider_id
|
||||
)
|
||||
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
subscription=subscription,
|
||||
)
|
||||
|
||||
is_auto_created: bool = credential_type in [CredentialType.OAUTH2, CredentialType.API_KEY]
|
||||
if is_auto_created:
|
||||
provider_id = TriggerProviderID(subscription.provider_id)
|
||||
provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
||||
tenant_id=tenant_id, provider_id=provider_id
|
||||
)
|
||||
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
subscription=subscription,
|
||||
)
|
||||
try:
|
||||
TriggerManager.unsubscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -280,8 +383,8 @@ class TriggerProviderService:
|
|||
except Exception as e:
|
||||
logger.exception("Error unsubscribing trigger", exc_info=e)
|
||||
|
||||
# Clear cache
|
||||
session.delete(subscription)
|
||||
# Clear cache
|
||||
delete_cache_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=subscription.provider_id,
|
||||
|
|
@ -688,3 +791,125 @@ class TriggerProviderService:
|
|||
)
|
||||
subscription.properties = dict(properties_encrypter.decrypt(subscription.properties))
|
||||
return subscription
|
||||
|
||||
@classmethod
|
||||
def verify_subscription_credentials(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_id: str,
|
||||
credentials: Mapping[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Verify credentials for an existing subscription without updating it.
|
||||
|
||||
This is used in edit mode to validate new credentials before rebuild.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param user_id: User ID
|
||||
:param provider_id: Provider identifier
|
||||
:param subscription_id: Subscription ID
|
||||
:param credentials: New credentials to verify
|
||||
:return: dict with 'verified' boolean
|
||||
"""
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
subscription = cls.get_subscription_by_id(
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError(f"Subscription {subscription_id} not found")
|
||||
|
||||
credential_type = CredentialType.of(subscription.credential_type)
|
||||
|
||||
# For API Key, validate the new credentials
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
new_credentials: dict[str, Any] = {
|
||||
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
try:
|
||||
provider_controller.validate_credentials(user_id, credentials=new_credentials)
|
||||
return {"verified": True}
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid credentials: {e}") from e
|
||||
|
||||
return {"verified": True}
|
||||
|
||||
@classmethod
|
||||
def rebuild_trigger_subscription(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_id: str,
|
||||
credentials: Mapping[str, Any],
|
||||
parameters: Mapping[str, Any],
|
||||
name: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Create a subscription builder for rebuilding an existing subscription.
|
||||
|
||||
This method creates a builder pre-filled with data from the rebuild request,
|
||||
keeping the same subscription_id and endpoint_id so the webhook URL remains unchanged.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param name: Name for the subscription
|
||||
:param subscription_id: Subscription ID
|
||||
:param provider_id: Provider identifier
|
||||
:param credentials: Credentials for the subscription
|
||||
:param parameters: Parameters for the subscription
|
||||
:return: SubscriptionBuilderApiEntity
|
||||
"""
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
subscription = TriggerProviderService.get_subscription_by_id(
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError(f"Subscription {subscription_id} not found")
|
||||
|
||||
credential_type = CredentialType.of(subscription.credential_type)
|
||||
if credential_type not in [CredentialType.OAUTH2, CredentialType.API_KEY]:
|
||||
raise ValueError("Credential type not supported for rebuild")
|
||||
|
||||
# TODO: Trying to invoke update api of the plugin trigger provider
|
||||
|
||||
# FALLBACK: If the update api is not implemented, delete the previous subscription and create a new one
|
||||
|
||||
# Delete the previous subscription
|
||||
user_id = subscription.user_id
|
||||
TriggerManager.unsubscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
subscription=subscription.to_entity(),
|
||||
credentials=subscription.credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
|
||||
# Create a new subscription with the same subscription_id and endpoint_id
|
||||
new_subscription: TriggerSubscriptionEntity = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=generate_plugin_trigger_endpoint_url(subscription.endpoint_id),
|
||||
parameters=parameters,
|
||||
credentials=credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription.id,
|
||||
name=name,
|
||||
parameters=parameters,
|
||||
credentials=credentials,
|
||||
properties=new_subscription.properties,
|
||||
expires_at=new_subscription.expires_at,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -453,11 +453,12 @@ class TriggerSubscriptionBuilderService:
|
|||
if not subscription_builder:
|
||||
return None
|
||||
|
||||
# response to validation endpoint
|
||||
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
||||
tenant_id=subscription_builder.tenant_id, provider_id=TriggerProviderID(subscription_builder.provider_id)
|
||||
)
|
||||
try:
|
||||
# response to validation endpoint
|
||||
controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider(
|
||||
tenant_id=subscription_builder.tenant_id,
|
||||
provider_id=TriggerProviderID(subscription_builder.provider_id),
|
||||
)
|
||||
dispatch_response: TriggerDispatchResponse = controller.dispatch(
|
||||
request=request,
|
||||
subscription=subscription_builder.to_subscription(),
|
||||
|
|
|
|||
|
|
@ -31,7 +31,8 @@ class WorkspaceService:
|
|||
assert tenant_account_join is not None, "TenantAccountJoin not found"
|
||||
tenant_info["role"] = tenant_account_join.role
|
||||
|
||||
can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo
|
||||
feature = FeatureService.get_features(tenant.id)
|
||||
can_replace_logo = feature.can_replace_logo
|
||||
|
||||
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
|
||||
base_url = dify_config.FILES_URL
|
||||
|
|
@ -46,5 +47,19 @@ class WorkspaceService:
|
|||
"remove_webapp_brand": remove_webapp_brand,
|
||||
"replace_webapp_logo": replace_webapp_logo,
|
||||
}
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
tenant_info["next_credit_reset_date"] = feature.next_credit_reset_date
|
||||
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
|
||||
if paid_pool:
|
||||
tenant_info["trial_credits"] = paid_pool.quota_limit
|
||||
tenant_info["trial_credits_used"] = paid_pool.quota_used
|
||||
else:
|
||||
trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
|
||||
if trial_pool:
|
||||
tenant_info["trial_credits"] = trial_pool.quota_limit
|
||||
tenant_info["trial_credits_used"] = trial_pool.quota_used
|
||||
|
||||
return tenant_info
|
||||
|
|
|
|||
|
|
@ -0,0 +1,236 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView as FlaskMethodView
|
||||
|
||||
_NEEDS_METHOD_VIEW_CLEANUP = False
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = FlaskMethodView
|
||||
_NEEDS_METHOD_VIEW_CLEANUP = True
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console.extension import (
|
||||
APIBasedExtensionAPI,
|
||||
APIBasedExtensionDetailAPI,
|
||||
CodeBasedExtensionAPI,
|
||||
)
|
||||
|
||||
if _NEEDS_METHOD_VIEW_CLEANUP:
|
||||
delattr(builtins, "MethodView")
|
||||
from models.account import AccountStatus
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
|
||||
|
||||
def _make_extension(
|
||||
*,
|
||||
name: str = "Sample Extension",
|
||||
api_endpoint: str = "https://example.com/api",
|
||||
api_key: str = "super-secret-key",
|
||||
) -> APIBasedExtension:
|
||||
extension = APIBasedExtension(
|
||||
tenant_id="tenant-123",
|
||||
name=name,
|
||||
api_endpoint=api_endpoint,
|
||||
api_key=api_key,
|
||||
)
|
||||
extension.id = f"{uuid.uuid4()}"
|
||||
extension.created_at = datetime.now(tz=UTC)
|
||||
return extension
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_console_guards(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
"""Bypass console decorators so handlers can run in isolation."""
|
||||
|
||||
import controllers.console.extension as extension_module
|
||||
from controllers.console import wraps as wraps_module
|
||||
|
||||
account = MagicMock()
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.current_tenant_id = "tenant-123"
|
||||
account.id = "account-123"
|
||||
account.is_authenticated = True
|
||||
|
||||
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.setattr("libs.login.dify_config.LOGIN_DISABLED", True)
|
||||
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||
monkeypatch.setattr(extension_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
|
||||
|
||||
# The login_required decorator consults the shared LocalProxy in libs.login.
|
||||
monkeypatch.setattr("libs.login.current_user", account)
|
||||
monkeypatch.setattr("libs.login.check_csrf_token", lambda *_, **__: None)
|
||||
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _restx_mask_defaults(app: Flask):
|
||||
app.config.setdefault("RESTX_MASK_HEADER", "X-Fields")
|
||||
app.config.setdefault("RESTX_MASK_SWAGGER", False)
|
||||
|
||||
|
||||
def test_code_based_extension_get_returns_service_data(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
service_result = {"entrypoint": "main:agent"}
|
||||
service_mock = MagicMock(return_value=service_result)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.CodeBasedExtensionService.get_code_based_extension",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/code-based-extension",
|
||||
method="GET",
|
||||
query_string={"module": "workflow.tools"},
|
||||
):
|
||||
response = CodeBasedExtensionAPI().get()
|
||||
|
||||
assert response == {"module": "workflow.tools", "data": service_result}
|
||||
service_mock.assert_called_once_with("workflow.tools")
|
||||
|
||||
|
||||
def test_api_based_extension_get_returns_tenant_extensions(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
extension = _make_extension(name="Weather API", api_key="abcdefghi123")
|
||||
service_mock = MagicMock(return_value=[extension])
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.APIBasedExtensionService.get_all_by_tenant_id",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/api-based-extension", method="GET"):
|
||||
response = APIBasedExtensionAPI().get()
|
||||
|
||||
assert response[0]["id"] == extension.id
|
||||
assert response[0]["name"] == "Weather API"
|
||||
assert response[0]["api_endpoint"] == extension.api_endpoint
|
||||
assert response[0]["api_key"].startswith(extension.api_key[:3])
|
||||
service_mock.assert_called_once_with("tenant-123")
|
||||
|
||||
|
||||
def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
saved_extension = _make_extension(name="Docs API", api_key="saved-secret")
|
||||
save_mock = MagicMock(return_value=saved_extension)
|
||||
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||
|
||||
payload = {
|
||||
"name": "Docs API",
|
||||
"api_endpoint": "https://docs.example.com/hook",
|
||||
"api_key": "plain-secret",
|
||||
}
|
||||
|
||||
with app.test_request_context("/console/api/api-based-extension", method="POST", json=payload):
|
||||
response = APIBasedExtensionAPI().post()
|
||||
|
||||
args, _ = save_mock.call_args
|
||||
created_extension: APIBasedExtension = args[0]
|
||||
assert created_extension.tenant_id == "tenant-123"
|
||||
assert created_extension.name == payload["name"]
|
||||
assert created_extension.api_endpoint == payload["api_endpoint"]
|
||||
assert created_extension.api_key == payload["api_key"]
|
||||
assert response["name"] == saved_extension.name
|
||||
save_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_api_based_extension_detail_get_fetches_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
extension = _make_extension(name="Docs API", api_key="abcdefg12345")
|
||||
service_mock = MagicMock(return_value=extension)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||
service_mock,
|
||||
)
|
||||
|
||||
extension_id = uuid.uuid4()
|
||||
with app.test_request_context(f"/console/api/api-based-extension/{extension_id}", method="GET"):
|
||||
response = APIBasedExtensionDetailAPI().get(extension_id)
|
||||
|
||||
assert response["id"] == extension.id
|
||||
assert response["name"] == extension.name
|
||||
service_mock.assert_called_once_with("tenant-123", str(extension_id))
|
||||
|
||||
|
||||
def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
existing_extension = _make_extension(name="Docs API", api_key="keep-me")
|
||||
get_mock = MagicMock(return_value=existing_extension)
|
||||
save_mock = MagicMock(return_value=existing_extension)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||
get_mock,
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||
|
||||
payload = {
|
||||
"name": "Docs API Updated",
|
||||
"api_endpoint": "https://docs.example.com/v2",
|
||||
"api_key": HIDDEN_VALUE,
|
||||
}
|
||||
|
||||
extension_id = uuid.uuid4()
|
||||
with app.test_request_context(
|
||||
f"/console/api/api-based-extension/{extension_id}",
|
||||
method="POST",
|
||||
json=payload,
|
||||
):
|
||||
response = APIBasedExtensionDetailAPI().post(extension_id)
|
||||
|
||||
assert existing_extension.name == payload["name"]
|
||||
assert existing_extension.api_endpoint == payload["api_endpoint"]
|
||||
assert existing_extension.api_key == "keep-me"
|
||||
save_mock.assert_called_once_with(existing_extension)
|
||||
assert response["name"] == payload["name"]
|
||||
|
||||
|
||||
def test_api_based_extension_detail_post_updates_api_key_when_provided(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
existing_extension = _make_extension(name="Docs API", api_key="old-secret")
|
||||
get_mock = MagicMock(return_value=existing_extension)
|
||||
save_mock = MagicMock(return_value=existing_extension)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||
get_mock,
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||
|
||||
payload = {
|
||||
"name": "Docs API Updated",
|
||||
"api_endpoint": "https://docs.example.com/v2",
|
||||
"api_key": "new-secret",
|
||||
}
|
||||
|
||||
extension_id = uuid.uuid4()
|
||||
with app.test_request_context(
|
||||
f"/console/api/api-based-extension/{extension_id}",
|
||||
method="POST",
|
||||
json=payload,
|
||||
):
|
||||
response = APIBasedExtensionDetailAPI().post(extension_id)
|
||||
|
||||
assert existing_extension.api_key == "new-secret"
|
||||
save_mock.assert_called_once_with(existing_extension)
|
||||
assert response["name"] == payload["name"]
|
||||
|
||||
|
||||
def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
existing_extension = _make_extension()
|
||||
get_mock = MagicMock(return_value=existing_extension)
|
||||
delete_mock = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||
get_mock,
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.delete", delete_mock)
|
||||
|
||||
extension_id = uuid.uuid4()
|
||||
with app.test_request_context(
|
||||
f"/console/api/api-based-extension/{extension_id}",
|
||||
method="DELETE",
|
||||
):
|
||||
response, status = APIBasedExtensionDetailAPI().delete(extension_id)
|
||||
|
||||
delete_mock.assert_called_once_with(existing_extension)
|
||||
assert response == {"result": "success"}
|
||||
assert status == 204
|
||||
|
|
@@ -0,0 +1,195 @@
|
|||
"""Unit tests for controllers.web.forgot_password endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import builtins
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
# Ensure flask_restx.api finds MethodView during import.
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _load_controller_module():
|
||||
"""Import controllers.web.forgot_password using a stub package."""
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
from types import ModuleType
|
||||
|
||||
parent_module_name = "controllers.web"
|
||||
module_name = f"{parent_module_name}.forgot_password"
|
||||
|
||||
if parent_module_name not in sys.modules:
|
||||
from flask_restx import Namespace
|
||||
|
||||
stub = ModuleType(parent_module_name)
|
||||
stub.__file__ = "controllers/web/__init__.py"
|
||||
stub.__path__ = ["controllers/web"]
|
||||
stub.__package__ = "controllers"
|
||||
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
|
||||
stub.web_ns = Namespace("web", description="Web API", path="/")
|
||||
sys.modules[parent_module_name] = stub
|
||||
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
forgot_password_module = _load_controller_module()
|
||||
ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi
|
||||
ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi
|
||||
ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Configure a minimal Flask app for request contexts."""
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_web_endpoint_guards():
|
||||
"""Stub enterprise and feature toggles used by route decorators."""
|
||||
|
||||
features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True),
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_controller_db():
|
||||
"""Replace controller-level db reference with a simple stub."""
|
||||
|
||||
fake_db = SimpleNamespace(engine=MagicMock(name="engine"))
|
||||
fake_wraps_db = SimpleNamespace(
|
||||
session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True))))
|
||||
)
|
||||
with (
|
||||
patch("controllers.web.forgot_password.db", fake_db),
|
||||
patch("controllers.console.wraps.db", fake_wraps_db),
|
||||
):
|
||||
yield fake_db
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10")
|
||||
def test_send_reset_email_success(
|
||||
mock_extract_ip: MagicMock,
|
||||
mock_is_ip_limit: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
mock_send_email: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password returns token when email exists and limits allow."""
|
||||
|
||||
mock_account = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = session_ctx
|
||||
session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "user@example.com"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "reset-token"}
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("203.0.113.10")
|
||||
mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US")
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token"))
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False)
|
||||
def test_check_token_success(
|
||||
mock_is_rate_limited: MagicMock,
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke: MagicMock,
|
||||
mock_generate: MagicMock,
|
||||
mock_reset_limit: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password/validity validates the code and refreshes token."""
|
||||
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "123456"}
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "code": "123456", "token": "old-token"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limited.assert_called_once_with("user@example.com")
|
||||
mock_get_data.assert_called_once_with("old-token")
|
||||
mock_revoke.assert_called_once_with("old-token")
|
||||
mock_generate.assert_called_once_with(
|
||||
"user@example.com",
|
||||
code="123456",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_reset_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_success(
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke_token: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
mock_token_bytes: MagicMock,
|
||||
mock_hash_password: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password/resets updates the stored password when token is valid."""
|
||||
|
||||
mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"}
|
||||
account = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = session_ctx
|
||||
session_ctx.execute.return_value.scalar_one_or_none.return_value = account
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_data.assert_called_once_with("reset-token")
|
||||
mock_revoke_token.assert_called_once_with("reset-token")
|
||||
mock_token_bytes.assert_called_once_with(16)
|
||||
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
|
||||
expected_password = base64.b64encode(b"hashed-value").decode()
|
||||
assert account.password == expected_password
|
||||
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
|
||||
assert account.password_salt == expected_salt
|
||||
session_ctx.commit.assert_called_once()
|
||||
|
|
@@ -96,7 +96,7 @@ class TestNotionExtractorAuthentication:
|
|||
def test_init_with_integration_token_fallback(self, mock_get_token, mock_config, mock_document_model):
|
||||
"""Test NotionExtractor falls back to integration token when credential not found."""
|
||||
# Arrange
|
||||
mock_get_token.return_value = None
|
||||
mock_get_token.side_effect = Exception("No credential id found")
|
||||
mock_config.NOTION_INTEGRATION_TOKEN = "integration-token-fallback"
|
||||
|
||||
# Act
|
||||
|
|
@@ -105,7 +105,7 @@ class TestNotionExtractorAuthentication:
|
|||
notion_obj_id="page-456",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant-789",
|
||||
credential_id="cred-123",
|
||||
credential_id=None,
|
||||
document_model=mock_document_model,
|
||||
)
|
||||
|
||||
|
|
@@ -117,7 +117,7 @@ class TestNotionExtractorAuthentication:
|
|||
def test_init_missing_credentials_raises_error(self, mock_get_token, mock_config, mock_document_model):
|
||||
"""Test NotionExtractor raises error when no credentials available."""
|
||||
# Arrange
|
||||
mock_get_token.return_value = None
|
||||
mock_get_token.side_effect = Exception("No credential id found")
|
||||
mock_config.NOTION_INTEGRATION_TOKEN = None
|
||||
|
||||
# Act & Assert
|
||||
|
|
@@ -127,7 +127,7 @@ class TestNotionExtractorAuthentication:
|
|||
notion_obj_id="page-456",
|
||||
notion_page_type="page",
|
||||
tenant_id="tenant-789",
|
||||
credential_id="cred-123",
|
||||
credential_id=None,
|
||||
document_model=mock_document_model,
|
||||
)
|
||||
assert "Must specify `integration_token`" in str(exc_info.value)
|
||||
|
|
|
|||
|
|
@@ -1,52 +1,109 @@
|
|||
import secrets
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
|
||||
from core.helper.ssrf_proxy import (
|
||||
SSRF_DEFAULT_MAX_RETRIES,
|
||||
_get_user_provided_host_header,
|
||||
make_request,
|
||||
)
|
||||
|
||||
|
||||
@patch("httpx.Client.request")
|
||||
def test_successful_request(mock_request):
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_successful_request(mock_get_client):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_request.return_value = mock_response
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
response = make_request("GET", "http://example.com")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@patch("httpx.Client.request")
|
||||
def test_retry_exceed_max_retries(mock_request):
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_retry_exceed_max_retries(mock_get_client):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
|
||||
side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES
|
||||
mock_request.side_effect = side_effects
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
|
||||
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
|
||||
|
||||
|
||||
@patch("httpx.Client.request")
|
||||
def test_retry_logic_success(mock_request):
|
||||
side_effects = []
|
||||
class TestGetUserProvidedHostHeader:
|
||||
"""Tests for _get_user_provided_host_header function."""
|
||||
|
||||
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
|
||||
status_code = secrets.choice(STATUS_FORCELIST)
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
side_effects.append(mock_response)
|
||||
def test_returns_none_when_headers_is_none(self):
|
||||
assert _get_user_provided_host_header(None) is None
|
||||
|
||||
mock_response_200 = MagicMock()
|
||||
mock_response_200.status_code = 200
|
||||
side_effects.append(mock_response_200)
|
||||
def test_returns_none_when_headers_is_empty(self):
|
||||
assert _get_user_provided_host_header({}) is None
|
||||
|
||||
mock_request.side_effect = side_effects
|
||||
def test_returns_none_when_host_header_not_present(self):
|
||||
headers = {"Content-Type": "application/json", "Authorization": "Bearer token"}
|
||||
assert _get_user_provided_host_header(headers) is None
|
||||
|
||||
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
|
||||
def test_returns_host_header_lowercase(self):
|
||||
headers = {"host": "example.com"}
|
||||
assert _get_user_provided_host_header(headers) == "example.com"
|
||||
|
||||
def test_returns_host_header_uppercase(self):
|
||||
headers = {"HOST": "example.com"}
|
||||
assert _get_user_provided_host_header(headers) == "example.com"
|
||||
|
||||
def test_returns_host_header_mixed_case(self):
|
||||
headers = {"HoSt": "example.com"}
|
||||
assert _get_user_provided_host_header(headers) == "example.com"
|
||||
|
||||
def test_returns_host_header_from_multiple_headers(self):
|
||||
headers = {"Content-Type": "application/json", "Host": "api.example.com", "Authorization": "Bearer token"}
|
||||
assert _get_user_provided_host_header(headers) == "api.example.com"
|
||||
|
||||
def test_returns_first_host_header_when_duplicates(self):
|
||||
headers = {"host": "first.com", "Host": "second.com"}
|
||||
# Should return the first one encountered (iteration order is preserved in dict)
|
||||
result = _get_user_provided_host_header(headers)
|
||||
assert result in ("first.com", "second.com")
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_host_header_preservation_without_user_header(mock_get_client):
|
||||
"""Test that when no Host header is provided, the default behavior is maintained."""
|
||||
mock_client = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
response = make_request("GET", "http://example.com")
|
||||
|
||||
assert response.status_code == 200
|
||||
# Host should not be set if not provided by user
|
||||
assert "Host" not in mock_request.headers or mock_request.headers.get("Host") is None
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
def test_host_header_preservation_with_user_header(mock_get_client):
|
||||
"""Test that user-provided Host header is preserved in the request."""
|
||||
mock_client = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_client.send.return_value = mock_response
|
||||
mock_client.request.return_value = mock_response
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
custom_host = "custom.example.com:8080"
|
||||
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
|
||||
assert mock_request.call_args_list[0][1].get("method") == "GET"
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,9 @@
|
|||
import re
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
|
||||
from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path
|
||||
|
||||
|
||||
class TestValidateUrl:
|
||||
|
|
@@ -136,3 +139,51 @@ class TestValidateProjectName:
|
|||
"""Test custom default name"""
|
||||
result = validate_project_name("", "Custom Default")
|
||||
assert result == "Custom Default"
|
||||
|
||||
|
||||
class TestGenerateDottedOrder:
|
||||
"""Test cases for generate_dotted_order function"""
|
||||
|
||||
def test_dotted_order_has_6_digit_microseconds(self):
|
||||
"""Test that timestamp includes full 6-digit microseconds for LangSmith API compatibility.
|
||||
|
||||
LangSmith API expects timestamps in format: YYYYMMDDTHHMMSSffffffZ (6-digit microseconds).
|
||||
Previously, the code truncated to 3 digits which caused API errors:
|
||||
'cannot parse .111 as .000000'
|
||||
"""
|
||||
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
|
||||
run_id = "test-run-id"
|
||||
result = generate_dotted_order(run_id, start_time)
|
||||
|
||||
# Extract timestamp portion (before the run_id)
|
||||
timestamp_match = re.match(r"^(\d{8}T\d{6})(\d+)Z", result)
|
||||
assert timestamp_match is not None, "Timestamp format should match YYYYMMDDTHHMMSSffffffZ"
|
||||
|
||||
microseconds = timestamp_match.group(2)
|
||||
assert len(microseconds) == 6, f"Microseconds should be 6 digits, got {len(microseconds)}: {microseconds}"
|
||||
|
||||
def test_dotted_order_format_matches_langsmith_expected(self):
|
||||
"""Test that dotted_order format matches LangSmith API expected format."""
|
||||
start_time = datetime(2025, 1, 15, 10, 30, 45, 123456)
|
||||
run_id = "abc123"
|
||||
result = generate_dotted_order(run_id, start_time)
|
||||
|
||||
# LangSmith expects: YYYYMMDDTHHMMSSffffffZ followed by run_id
|
||||
assert result == "20250115T103045123456Zabc123"
|
||||
|
||||
def test_dotted_order_with_parent(self):
|
||||
"""Test dotted_order generation with parent order uses dot separator."""
|
||||
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
|
||||
run_id = "child-run-id"
|
||||
parent_order = "20251223T041955000000Zparent-run-id"
|
||||
result = generate_dotted_order(run_id, start_time, parent_order)
|
||||
|
||||
assert result == "20251223T041955000000Zparent-run-id.20251223T041955111000Zchild-run-id"
|
||||
|
||||
def test_dotted_order_without_parent_has_no_dot(self):
|
||||
"""Test dotted_order generation without parent has no dot separator."""
|
||||
start_time = datetime(2025, 12, 23, 4, 19, 55, 111000)
|
||||
run_id = "test-run-id"
|
||||
result = generate_dotted_order(run_id, start_time, None)
|
||||
|
||||
assert "." not in result
|
||||
|
|
|
|||
|
|
@@ -1,5 +1,7 @@
|
|||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
|
|
@@ -25,3 +27,35 @@ def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
|
|||
|
||||
assert job_id is not None
|
||||
assert isinstance(job_id, str)
|
||||
|
||||
|
||||
def test_build_url_normalizes_slashes_for_crawl(mocker: MockerFixture):
|
||||
api_key = "fc-"
|
||||
base_urls = ["https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"]
|
||||
for base in base_urls:
|
||||
app = FirecrawlApp(api_key=api_key, base_url=base)
|
||||
mock_post = mocker.patch("httpx.post")
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"id": "job123"}
|
||||
mock_post.return_value = mock_resp
|
||||
app.crawl_url("https://example.com", params=None)
|
||||
called_url = mock_post.call_args[0][0]
|
||||
assert called_url == "https://custom.firecrawl.dev/v2/crawl"
|
||||
|
||||
|
||||
def test_error_handler_handles_non_json_error_bodies(mocker: MockerFixture):
|
||||
api_key = "fc-"
|
||||
app = FirecrawlApp(api_key=api_key, base_url="https://custom.firecrawl.dev/")
|
||||
mock_post = mocker.patch("httpx.post")
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 404
|
||||
mock_resp.text = "Not Found"
|
||||
mock_resp.json.side_effect = Exception("Not JSON")
|
||||
mock_post.return_value = mock_resp
|
||||
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
app.scrape_url("https://example.com")
|
||||
|
||||
# Should not raise a JSONDecodeError; current behavior reports status code only
|
||||
assert str(excinfo.value) == "Failed to scrape URL. Status code: 404"
|
||||
|
|
|
|||
|
|
@@ -1,3 +1,4 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
|
|
@@ -110,9 +111,11 @@ class TestFirecrawlAuth:
|
|||
@pytest.mark.parametrize(
|
||||
("status_code", "response_text", "has_json_error", "expected_error_contains"),
|
||||
[
|
||||
(403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
|
||||
(404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
|
||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
(403, '{"error": "Forbidden"}', False, "Failed to authorize. Status code: 403. Error: Forbidden"),
|
||||
# empty body falls back to generic message
|
||||
(404, "", True, "Failed to authorize. Status code: 404. Error: Unknown error occurred"),
|
||||
# non-JSON body is surfaced directly
|
||||
(401, "Not JSON", True, "Failed to authorize. Status code: 401. Error: Not JSON"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
|
|
@@ -124,12 +127,14 @@ class TestFirecrawlAuth:
|
|||
mock_response.status_code = status_code
|
||||
mock_response.text = response_text
|
||||
if has_json_error:
|
||||
mock_response.json.side_effect = Exception("Not JSON")
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Not JSON", "", 0)
|
||||
else:
|
||||
mock_response.json.return_value = {"error": "Forbidden"}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
auth_instance.validate_credentials()
|
||||
assert expected_error_contains in str(exc_info.value)
|
||||
assert str(exc_info.value) == expected_error_contains
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception_type", "exception_message"),
|
||||
|
|
@@ -164,20 +169,21 @@ class TestFirecrawlAuth:
|
|||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_post):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
"""Test that custom base URL is used in validation and normalized"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
credentials = {
|
||||
"auth_type": "bearer",
|
||||
"config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
|
||||
}
|
||||
auth = FirecrawlAuth(credentials)
|
||||
result = auth.validate_credentials()
|
||||
for base in ("https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"):
|
||||
credentials = {
|
||||
"auth_type": "bearer",
|
||||
"config": {"api_key": "test_api_key_123", "base_url": base},
|
||||
}
|
||||
auth = FirecrawlAuth(credentials)
|
||||
result = auth.validate_credentials()
|
||||
|
||||
assert result is True
|
||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||
assert result is True
|
||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
|
||||
|
|
|
|||
|
|
@@ -619,8 +619,13 @@ class TestTenantService:
|
|||
mock_tenant_instance.name = "Test User's Workspace"
|
||||
mock_tenant_class.return_value = mock_tenant_instance
|
||||
|
||||
# Execute test
|
||||
TenantService.create_owner_tenant_if_not_exist(mock_account)
|
||||
# Mock the db import in CreditPoolService to avoid database connection
|
||||
with patch("services.credit_pool_service.db") as mock_credit_pool_db:
|
||||
mock_credit_pool_db.session.add = MagicMock()
|
||||
mock_credit_pool_db.session.commit = MagicMock()
|
||||
|
||||
# Execute test
|
||||
TenantService.create_owner_tenant_if_not_exist(mock_account)
|
||||
|
||||
# Verify tenant was created with correct parameters
|
||||
mock_db_dependencies["db"].session.add.assert_called()
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,71 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
|
||||
from models import Account
|
||||
from services import app_dsl_service
|
||||
from services.app_dsl_service import AppDslService, ImportMode, ImportStatus
|
||||
|
||||
|
||||
def _build_response(url: str, status_code: int, content: bytes = b"") -> httpx.Response:
|
||||
request = httpx.Request("GET", url)
|
||||
return httpx.Response(status_code=status_code, request=request, content=content)
|
||||
|
||||
|
||||
def _pending_yaml_content(version: str = "99.0.0") -> bytes:
|
||||
return (f'version: "{version}"\nkind: app\napp:\n name: Loop Test\n mode: workflow\n').encode()
|
||||
|
||||
|
||||
def _account_mock() -> MagicMock:
|
||||
account = MagicMock(spec=Account)
|
||||
account.current_tenant_id = "tenant-1"
|
||||
return account
|
||||
|
||||
|
||||
def test_import_app_yaml_url_user_attachments_keeps_original_url(monkeypatch):
|
||||
yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml"
|
||||
raw_url = "https://raw.githubusercontent.com/user-attachments/files/24290802/loop-test.yml"
|
||||
yaml_bytes = _pending_yaml_content()
|
||||
|
||||
def fake_get(url: str, **kwargs):
|
||||
if url == raw_url:
|
||||
return _build_response(url, status_code=404)
|
||||
assert url == yaml_url
|
||||
return _build_response(url, status_code=200, content=yaml_bytes)
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
import_mode=ImportMode.YAML_URL,
|
||||
yaml_url=yaml_url,
|
||||
)
|
||||
|
||||
assert result.status == ImportStatus.PENDING
|
||||
assert result.imported_dsl_version == "99.0.0"
|
||||
|
||||
|
||||
def test_import_app_yaml_url_github_blob_rewrites_to_raw(monkeypatch):
|
||||
yaml_url = "https://github.com/acme/repo/blob/main/app.yml"
|
||||
raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml"
|
||||
yaml_bytes = _pending_yaml_content()
|
||||
|
||||
requested_urls: list[str] = []
|
||||
|
||||
def fake_get(url: str, **kwargs):
|
||||
requested_urls.append(url)
|
||||
assert url == raw_url
|
||||
return _build_response(url, status_code=200, content=yaml_bytes)
|
||||
|
||||
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get)
|
||||
|
||||
service = AppDslService(MagicMock())
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
import_mode=ImportMode.YAML_URL,
|
||||
yaml_url=yaml_url,
|
||||
)
|
||||
|
||||
assert result.status == ImportStatus.PENDING
|
||||
assert requested_urls == [raw_url]
|
||||
|
|
@@ -23,6 +23,10 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
|
|||
- Navigate to the `docker` directory.
|
||||
- Copy the `.env.example` file to a new file named `.env` by running `cp .env.example .env`.
|
||||
- Customize the `.env` file as needed. Refer to the `.env.example` file for detailed configuration options.
|
||||
- **Optional (Recommended for upgrades)**:
|
||||
You may use the environment synchronization tool to help keep your `.env` file aligned with the latest `.env.example` updates, while preserving your custom settings.
|
||||
This is especially useful when upgrading Dify or managing a large, customized `.env` file.
|
||||
See the [Environment Variables Synchronization](#environment-variables-synchronization) section below.
|
||||
1. **Running the Services**:
|
||||
- Execute `docker compose up` from the `docker` directory to start the services.
|
||||
- To specify a vector database, set the `VECTOR_STORE` variable in your `.env` file to the desired service, such as `milvus`, `weaviate`, or `opensearch` (see the example below).
|
||||
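A minimal sketch of that step, assuming Weaviate as the vector store and a detached start (the appended line and flags are illustrative; edit the existing `VECTOR_STORE` entry instead if your `.env` already defines one):

```bash
# Illustrative only: pick Weaviate as the vector store, then start the stack
cd docker
cp .env.example .env                  # skip if you already have a .env
echo "VECTOR_STORE=weaviate" >> .env  # or edit the existing VECTOR_STORE line
docker compose up -d
```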
|
|
@@ -111,6 +115,47 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w
|
|||
|
||||
- Each service like `nginx`, `redis`, `db`, and vector databases have specific environment variables that are directly referenced in the `docker-compose.yaml`.
|
||||
|
||||
### Environment Variables Synchronization
|
||||
|
||||
When upgrading Dify or pulling the latest changes, new environment variables may be introduced in `.env.example`.
|
||||
|
||||
To help keep your existing `.env` file up to date **without losing your custom values**, an optional environment variables synchronization tool is provided.
|
||||
|
||||
> This tool performs a **one-way synchronization** from `.env.example` to `.env`.
|
||||
> Existing values in `.env` are never overwritten automatically.
|
||||
|
||||
#### `dify-env-sync.sh` (Optional)
|
||||
|
||||
This script compares your current `.env` file with the latest `.env.example` template, adds newly introduced environment variables, and reports changed defaults while always preserving your existing values.
|
||||
|
||||
**What it does**
|
||||
|
||||
- Creates a backup of the current `.env` file before making any changes
|
||||
- Synchronizes newly added environment variables from `.env.example`
|
||||
- Preserves all existing custom values in `.env`
|
||||
- Displays differences and variables removed from `.env.example` for review
|
||||
|
||||
**Backup behavior**
|
||||
|
||||
Before synchronization, the current `.env` file is saved to the `env-backup/` directory with a timestamped filename
|
||||
(e.g. `env-backup/.env.backup_20231218_143022`).
|
||||
|
||||
**When to use**
|
||||
|
||||
- After upgrading Dify to a newer version
|
||||
- When `.env.example` has been updated with new environment variables
|
||||
- When managing a large or heavily customized `.env` file
|
||||
|
||||
**Usage**
|
||||
|
||||
```bash
|
||||
# Grant execution permission (first time only)
|
||||
chmod +x dify-env-sync.sh
|
||||
|
||||
# Run the synchronization
|
||||
./dify-env-sync.sh
|
||||
```
|
||||
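As a rough illustration of the one-way behavior (the variable names and values below are hypothetical; `NEW_FEATURE_FLAG` is not a real Dify setting), a customized value survives the sync while a variable that exists only in the new `.env.example` is added:

```bash
# Before running the script (hypothetical contents)
#   .env         : UPLOAD_FILE_SIZE_LIMIT=50
#   .env.example : UPLOAD_FILE_SIZE_LIMIT=15
#                  NEW_FEATURE_FLAG=false
./dify-env-sync.sh
# After running the script
#   .env         : UPLOAD_FILE_SIZE_LIMIT=50   # custom value preserved
#                  NEW_FEATURE_FLAG=false      # new variable added from .env.example
```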
|
||||
### Additional Information
|
||||
|
||||
- **Continuous Improvement Phase**: We are actively seeking feedback from the community to refine and enhance the deployment process. As more users adopt this new method, we will continue to make improvements based on your experiences and suggestions.
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,465 @@
|
|||
#!/bin/bash
|
||||
|
||||
# ================================================================
|
||||
# Dify Environment Variables Synchronization Script
|
||||
#
|
||||
# Features:
|
||||
# - Synchronize latest settings from .env.example to .env
|
||||
# - Preserve custom settings in existing .env
|
||||
# - Add new environment variables
|
||||
# - Detect removed environment variables
|
||||
# - Create backup files
|
||||
# ================================================================
|
||||
|
||||
set -eo pipefail # Exit on error and pipe failures (safer for complex variable handling)
|
||||
|
||||
# Error handling function
|
||||
# Arguments:
|
||||
# $1 - Line number where error occurred
|
||||
# $2 - Error code
|
||||
handle_error() {
|
||||
local line_no=$1
|
||||
local error_code=$2
|
||||
echo -e "\033[0;31m[ERROR]\033[0m Script error: line $line_no with error code $error_code" >&2
|
||||
echo -e "\033[0;31m[ERROR]\033[0m Debug info: current working directory $(pwd)" >&2
|
||||
exit $error_code
|
||||
}
|
||||
|
||||
# Set error trap
|
||||
trap 'handle_error ${LINENO} $?' ERR
|
||||
|
||||
# Color settings for output
|
||||
readonly RED='\033[0;31m'
|
||||
readonly GREEN='\033[0;32m'
|
||||
readonly YELLOW='\033[1;33m'
|
||||
readonly BLUE='\033[0;34m'
|
||||
readonly NC='\033[0m' # No Color
|
||||
|
||||
# Logging functions
|
||||
# Print informational message in blue
|
||||
# Arguments: $1 - Message to print
|
||||
log_info() {
|
||||
echo -e "${BLUE}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
# Print success message in green
|
||||
# Arguments: $1 - Message to print
|
||||
log_success() {
|
||||
echo -e "${GREEN}[SUCCESS]${NC} $1"
|
||||
}
|
||||
|
||||
# Print warning message in yellow
|
||||
# Arguments: $1 - Message to print
|
||||
log_warning() {
|
||||
echo -e "${YELLOW}[WARNING]${NC} $1" >&2
|
||||
}
|
||||
|
||||
# Print error message in red to stderr
|
||||
# Arguments: $1 - Message to print
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1" >&2
|
||||
}
|
||||
|
||||
# Check for required files and create .env if missing
|
||||
# Verifies that .env.example exists and creates .env from template if needed
|
||||
check_files() {
|
||||
log_info "Checking required files..."
|
||||
|
||||
if [[ ! -f ".env.example" ]]; then
|
||||
log_error ".env.example file not found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ ! -f ".env" ]]; then
|
||||
log_warning ".env file does not exist. Creating from .env.example."
|
||||
cp ".env.example" ".env"
|
||||
log_success ".env file created"
|
||||
fi
|
||||
|
||||
log_success "Required files verified"
|
||||
}
|
||||
|
||||
# Create timestamped backup of .env file
|
||||
# Creates env-backup directory if needed and backs up current .env file
|
||||
create_backup() {
|
||||
local timestamp=$(date +"%Y%m%d_%H%M%S")
|
||||
local backup_dir="env-backup"
|
||||
|
||||
# Create backup directory if it doesn't exist
|
||||
if [[ ! -d "$backup_dir" ]]; then
|
||||
mkdir -p "$backup_dir"
|
||||
log_info "Created backup directory: $backup_dir"
|
||||
fi
|
||||
|
||||
if [[ -f ".env" ]]; then
|
||||
local backup_file="${backup_dir}/.env.backup_${timestamp}"
|
||||
cp ".env" "$backup_file"
|
||||
log_success "Backed up existing .env to $backup_file"
|
||||
fi
|
||||
}
|
||||
|
||||
# Detect differences between .env and .env.example (optimized for large files)
|
||||
detect_differences() {
|
||||
log_info "Detecting differences between .env and .env.example..."
|
||||
|
||||
# Create secure temporary directory
|
||||
local temp_dir=$(mktemp -d)
|
||||
local temp_diff="$temp_dir/env_diff"
|
||||
|
||||
# Store diff file path as global variable
|
||||
declare -g DIFF_FILE="$temp_diff"
|
||||
declare -g TEMP_DIR="$temp_dir"
|
||||
|
||||
# Initialize difference file
|
||||
> "$temp_diff"
|
||||
|
||||
# Use awk for efficient comparison (much faster for large files)
|
||||
local diff_count=$(awk -F= '
|
||||
BEGIN { OFS="\x01" }
|
||||
FNR==NR {
|
||||
if (!/^[[:space:]]*#/ && !/^[[:space:]]*$/ && /=/) {
|
||||
gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1)
|
||||
key = $1
|
||||
value = substr($0, index($0,"=")+1)
|
||||
gsub(/^[[:space:]]+|[[:space:]]+$/, "", value)
|
||||
env_values[key] = value
|
||||
}
|
||||
next
|
||||
}
|
||||
{
|
||||
if (!/^[[:space:]]*#/ && !/^[[:space:]]*$/ && /=/) {
|
||||
gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1)
|
||||
key = $1
|
||||
example_value = substr($0, index($0,"=")+1)
|
||||
gsub(/^[[:space:]]+|[[:space:]]+$/, "", example_value)
|
||||
|
||||
if (key in env_values && env_values[key] != example_value) {
|
||||
print key, env_values[key], example_value > "'$temp_diff'"
|
||||
diff_count++
|
||||
}
|
||||
}
|
||||
}
|
||||
END { print diff_count }
|
||||
' .env .env.example)
|
||||
|
||||
if [[ $diff_count -gt 0 ]]; then
|
||||
log_success "Detected differences in $diff_count environment variables"
|
||||
# Show detailed differences
|
||||
show_differences_detail
|
||||
else
|
||||
log_info "No differences detected"
|
||||
fi
|
||||
}
|
||||
|
||||
# Parse environment variable line
|
||||
# Extracts key-value pairs from .env file format lines
|
||||
# Arguments:
|
||||
# $1 - Line to parse
|
||||
# Returns:
|
||||
# 0 - Success, outputs "key|value" format
|
||||
# 1 - Skip (empty line, comment, or invalid format)
|
||||
parse_env_line() {
|
||||
local line="$1"
|
||||
local key=""
|
||||
local value=""
|
||||
|
||||
# Skip empty lines or comment lines
|
||||
[[ -z "$line" || "$line" =~ ^[[:space:]]*# ]] && return 1
|
||||
|
||||
# Split by =
|
||||
if [[ "$line" =~ ^([^=]+)=(.*)$ ]]; then
|
||||
key="${BASH_REMATCH[1]}"
|
||||
value="${BASH_REMATCH[2]}"
|
||||
|
||||
# Remove leading and trailing whitespace
|
||||
key=$(echo "$key" | sed 's/^[[:space:]]*//; s/[[:space:]]*$//')
|
||||
value=$(echo "$value" | sed 's/^[[:space:]]*//; s/[[:space:]]*$//')
|
||||
|
||||
if [[ -n "$key" ]]; then
|
||||
echo "$key|$value"
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
# Show detailed differences
|
||||
show_differences_detail() {
|
||||
log_info ""
|
||||
log_info "=== Environment Variable Differences ==="
|
||||
|
||||
# Read differences from the already created diff file
|
||||
if [[ ! -s "$DIFF_FILE" ]]; then
|
||||
log_info "No differences to display"
|
||||
return
|
||||
fi
|
||||
|
||||
# Display differences
|
||||
local count=1
|
||||
while IFS=$'\x01' read -r key env_value example_value; do
|
||||
echo ""
|
||||
echo -e "${YELLOW}[$count] $key${NC}"
|
||||
echo -e " ${GREEN}.env (current)${NC} : ${env_value}"
|
||||
echo -e " ${BLUE}.env.example (recommended)${NC}: ${example_value}"
|
||||
|
||||
# Analyze value changes
|
||||
analyze_value_change "$env_value" "$example_value"
|
||||
((count++))
|
||||
done < "$DIFF_FILE"
|
||||
|
||||
echo ""
|
||||
log_info "=== Difference Analysis Complete ==="
|
||||
log_info "Note: Consider changing to the recommended values above."
|
||||
log_info "Current implementation preserves .env values."
|
||||
echo ""
|
||||
}
|
||||
|
||||
# Analyze value changes
|
||||
analyze_value_change() {
|
||||
local current_value="$1"
|
||||
local recommended_value="$2"
|
||||
|
||||
# Analyze value characteristics
|
||||
local analysis=""
|
||||
|
||||
# Empty value check
|
||||
if [[ -z "$current_value" && -n "$recommended_value" ]]; then
|
||||
analysis=" ${RED}→ Setting from empty to recommended value${NC}"
|
||||
elif [[ -n "$current_value" && -z "$recommended_value" ]]; then
|
||||
analysis=" ${RED}→ Recommended value changed to empty${NC}"
|
||||
# Numeric check - using arithmetic evaluation for robust comparison
|
||||
elif [[ "$current_value" =~ ^[0-9]+$ && "$recommended_value" =~ ^[0-9]+$ ]]; then
|
||||
# Use arithmetic evaluation to handle leading zeros correctly
|
||||
if (( 10#$current_value < 10#$recommended_value )); then
|
||||
analysis=" ${BLUE}→ Numeric increase (${current_value} < ${recommended_value})${NC}"
|
||||
elif (( 10#$current_value > 10#$recommended_value )); then
|
||||
analysis=" ${YELLOW}→ Numeric decrease (${current_value} > ${recommended_value})${NC}"
|
||||
fi
|
||||
# Boolean check
|
||||
elif [[ "$current_value" =~ ^(true|false)$ && "$recommended_value" =~ ^(true|false)$ ]]; then
|
||||
if [[ "$current_value" != "$recommended_value" ]]; then
|
||||
analysis=" ${BLUE}→ Boolean value change (${current_value} → ${recommended_value})${NC}"
|
||||
fi
|
||||
# URL/endpoint check
|
||||
elif [[ "$current_value" =~ ^https?:// || "$recommended_value" =~ ^https?:// ]]; then
|
||||
analysis=" ${BLUE}→ URL/endpoint change${NC}"
|
||||
# File path check
|
||||
elif [[ "$current_value" =~ ^/ || "$recommended_value" =~ ^/ ]]; then
|
||||
analysis=" ${BLUE}→ File path change${NC}"
|
||||
else
|
||||
# Length comparison
|
||||
local current_len=${#current_value}
|
||||
local recommended_len=${#recommended_value}
|
||||
if [[ $current_len -ne $recommended_len ]]; then
|
||||
analysis=" ${YELLOW}→ String length change (${current_len} → ${recommended_len} characters)${NC}"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -n "$analysis" ]]; then
|
||||
echo -e "$analysis"
|
||||
fi
|
||||
}
|
||||
|
||||
# Synchronize .env file with .env.example while preserving custom values
|
||||
# Creates a new .env file based on .env.example structure, preserving existing custom values
|
||||
# Global variables used: DIFF_FILE, TEMP_DIR
|
||||
sync_env_file() {
|
||||
log_info "Starting partial synchronization of .env file..."
|
||||
|
||||
local new_env_file=".env.new"
|
||||
local preserved_count=0
|
||||
local updated_count=0
|
||||
|
||||
# Pre-process diff file for efficient lookup
|
||||
local lookup_file=""
|
||||
if [[ -f "$DIFF_FILE" && -s "$DIFF_FILE" ]]; then
|
||||
lookup_file="${DIFF_FILE}.lookup"
|
||||
# Create sorted lookup file for fast search
|
||||
sort "$DIFF_FILE" > "$lookup_file"
|
||||
log_info "Created lookup file for $(wc -l < "$DIFF_FILE") preserved values"
|
||||
fi
|
||||
|
||||
# Use AWK for efficient processing (much faster than bash loop for large files)
|
||||
log_info "Processing $(wc -l < .env.example) lines with AWK..."
|
||||
|
||||
local preserved_keys_file="${TEMP_DIR}/preserved_keys"
|
||||
local awk_preserved_count_file="${TEMP_DIR}/awk_preserved_count"
|
||||
local awk_updated_count_file="${TEMP_DIR}/awk_updated_count"
|
||||
|
||||
awk -F'=' -v lookup_file="$lookup_file" -v preserved_file="$preserved_keys_file" \
|
||||
-v preserved_count_file="$awk_preserved_count_file" -v updated_count_file="$awk_updated_count_file" '
|
||||
BEGIN {
|
||||
preserved_count = 0
|
||||
updated_count = 0
|
||||
|
||||
# Load preserved values if lookup file exists
|
||||
if (lookup_file != "") {
|
||||
while ((getline line < lookup_file) > 0) {
|
||||
split(line, parts, "\x01")
|
||||
key = parts[1]
|
||||
value = parts[2]
|
||||
preserved_values[key] = value
|
||||
}
|
||||
close(lookup_file)
|
||||
}
|
||||
}
|
||||
|
||||
# Process each line
|
||||
{
|
||||
# Check if this is an environment variable line
|
||||
if (/^[[:space:]]*[A-Za-z_][A-Za-z0-9_]*[[:space:]]*=/) {
|
||||
# Extract key
|
||||
key = $1
|
||||
gsub(/^[[:space:]]+|[[:space:]]+$/, "", key)
|
||||
|
||||
# Check if key should be preserved
|
||||
if (key in preserved_values) {
|
||||
print key "=" preserved_values[key]
|
||||
print key > preserved_file
|
||||
preserved_count++
|
||||
} else {
|
||||
print $0
|
||||
updated_count++
|
||||
}
|
||||
} else {
|
||||
# Not an env var line, preserve as-is
|
||||
print $0
|
||||
}
|
||||
}
|
||||
|
||||
END {
|
||||
print preserved_count > preserved_count_file
|
||||
print updated_count > updated_count_file
|
||||
}
|
||||
' .env.example > "$new_env_file"
|
||||
|
||||
# Read counters and preserved keys
|
||||
if [[ -f "$awk_preserved_count_file" ]]; then
|
||||
preserved_count=$(cat "$awk_preserved_count_file")
|
||||
fi
|
||||
if [[ -f "$awk_updated_count_file" ]]; then
|
||||
updated_count=$(cat "$awk_updated_count_file")
|
||||
fi
|
||||
|
||||
# Show what was preserved
|
||||
if [[ -f "$preserved_keys_file" ]]; then
|
||||
while read -r key; do
|
||||
[[ -n "$key" ]] && log_info " Preserved: $key (.env value)"
|
||||
done < "$preserved_keys_file"
|
||||
fi
|
||||
|
||||
# Clean up lookup file
|
||||
[[ -n "$lookup_file" ]] && rm -f "$lookup_file"
|
||||
|
||||
# Replace the original .env file
|
||||
if mv "$new_env_file" ".env"; then
|
||||
log_success "Successfully created new .env file"
|
||||
else
|
||||
log_error "Failed to replace .env file"
|
||||
rm -f "$new_env_file"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Clean up difference file and temporary directory
|
||||
if [[ -n "${TEMP_DIR:-}" ]]; then
|
||||
rm -rf "${TEMP_DIR}"
|
||||
unset TEMP_DIR
|
||||
fi
|
||||
if [[ -n "${DIFF_FILE:-}" ]]; then
|
||||
unset DIFF_FILE
|
||||
fi
|
||||
|
||||
log_success "Partial synchronization of .env file completed"
|
||||
log_info " Preserved .env values: $preserved_count"
|
||||
log_info " Updated to .env.example values: $updated_count"
|
||||
}
|
||||
|
||||
# Detect removed environment variables
|
||||
detect_removed_variables() {
|
||||
log_info "Detecting removed environment variables..."
|
||||
|
||||
if [[ ! -f ".env" ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
# Use temporary files for efficient lookup
|
||||
local temp_dir="${TEMP_DIR:-$(mktemp -d)}"
|
||||
local temp_example_keys="$temp_dir/example_keys"
|
||||
local temp_current_keys="$temp_dir/current_keys"
|
||||
local cleanup_temp_dir=""
|
||||
|
||||
# Set flag if we created a new temp directory
|
||||
if [[ -z "${TEMP_DIR:-}" ]]; then
|
||||
cleanup_temp_dir="$temp_dir"
|
||||
fi
|
||||
|
||||
# Get keys from .env.example and .env, sorted for comm
|
||||
awk -F= '!/^[[:space:]]*#/ && /=/ {gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1); print $1}' .env.example | sort > "$temp_example_keys"
|
||||
awk -F= '!/^[[:space:]]*#/ && /=/ {gsub(/^[[:space:]]+|[[:space:]]+$/, "", $1); print $1}' .env | sort > "$temp_current_keys"
|
||||
|
||||
# Get keys from existing .env and check for removals
|
||||
local removed_vars=()
|
||||
while IFS= read -r var; do
|
||||
removed_vars+=("$var")
|
||||
done < <(comm -13 "$temp_example_keys" "$temp_current_keys")
|
||||
|
||||
# Clean up temporary files if we created a new temp directory
|
||||
if [[ -n "$cleanup_temp_dir" ]]; then
|
||||
rm -rf "$cleanup_temp_dir"
|
||||
fi
|
||||
|
||||
if [[ ${#removed_vars[@]} -gt 0 ]]; then
|
||||
log_warning "The following environment variables have been removed from .env.example:"
|
||||
for var in "${removed_vars[@]}"; do
|
||||
log_warning " - $var"
|
||||
done
|
||||
log_warning "Consider manually removing these variables from .env"
|
||||
else
|
||||
log_success "No removed environment variables found"
|
||||
fi
|
||||
}
|
||||
|
||||
# Show statistics
|
||||
show_statistics() {
|
||||
log_info "Synchronization statistics:"
|
||||
|
||||
local total_example=$(grep -c "^[^#]*=" .env.example 2>/dev/null || echo "0")
|
||||
local total_env=$(grep -c "^[^#]*=" .env 2>/dev/null || echo "0")
|
||||
|
||||
log_info " .env.example environment variables: $total_example"
|
||||
log_info " .env environment variables: $total_env"
|
||||
}
|
||||
|
||||
# Main execution function
|
||||
# Orchestrates the complete synchronization process in the correct order
|
||||
main() {
|
||||
log_info "=== Dify Environment Variables Synchronization Script ==="
|
||||
log_info "Execution started: $(date)"
|
||||
|
||||
# Check prerequisites
|
||||
check_files
|
||||
|
||||
# Create backup
|
||||
create_backup
|
||||
|
||||
# Detect differences
|
||||
detect_differences
|
||||
|
||||
# Detect removed variables (before sync)
|
||||
detect_removed_variables
|
||||
|
||||
# Synchronize environment file
|
||||
sync_env_file
|
||||
|
||||
# Show statistics
|
||||
show_statistics
|
||||
|
||||
log_success "=== Synchronization process completed successfully ==="
|
||||
log_info "Execution finished: $(date)"
|
||||
}
|
||||
|
||||
# Execute main function only when script is run directly
|
||||
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
|
||||
main "$@"
|
||||
fi
|
||||
|
|
@@ -1,48 +1,40 @@
|
|||
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
|
||||
# Dependencies
|
||||
node_modules/
|
||||
|
||||
# dependencies
|
||||
/node_modules
|
||||
/.pnp
|
||||
.pnp.js
|
||||
# Build output
|
||||
dist/
|
||||
|
||||
# testing
|
||||
/coverage
|
||||
# Testing
|
||||
coverage/
|
||||
|
||||
# next.js
|
||||
/.next/
|
||||
/out/
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# production
|
||||
/build
|
||||
|
||||
# misc
|
||||
# OS
|
||||
.DS_Store
|
||||
*.pem
|
||||
Thumbs.db
|
||||
|
||||
# debug
|
||||
# Debug logs
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
.pnpm-debug.log*
|
||||
pnpm-debug.log*
|
||||
|
||||
# local env files
|
||||
.env*.local
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
|
||||
# vercel
|
||||
.vercel
|
||||
|
||||
# typescript
|
||||
# TypeScript
|
||||
*.tsbuildinfo
|
||||
next-env.d.ts
|
||||
|
||||
# npm
|
||||
# Lock files (use pnpm-lock.yaml in CI if needed)
|
||||
package-lock.json
|
||||
yarn.lock
|
||||
|
||||
# yarn
|
||||
.pnp.cjs
|
||||
.pnp.loader.mjs
|
||||
.yarn/
|
||||
.yarnrc.yml
|
||||
|
||||
# pmpm
|
||||
pnpm-lock.yaml
|
||||
# Misc
|
||||
*.pem
|
||||
*.tgz
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,22 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangGenius
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
|
@@ -13,54 +13,92 @@ npm install dify-client
|
|||
After installing the SDK, you can use it in your project like this:
|
||||
|
||||
```js
|
||||
import { DifyClient, ChatClient, CompletionClient } from 'dify-client'
|
||||
import {
|
||||
DifyClient,
|
||||
ChatClient,
|
||||
CompletionClient,
|
||||
WorkflowClient,
|
||||
KnowledgeBaseClient,
|
||||
WorkspaceClient
|
||||
} from 'dify-client'
|
||||
|
||||
const API_KEY = 'your-api-key-here'
|
||||
const user = `random-user-id`
|
||||
const API_KEY = 'your-app-api-key'
|
||||
const DATASET_API_KEY = 'your-dataset-api-key'
|
||||
const user = 'random-user-id'
|
||||
const query = 'Please tell me a short story in 10 words or less.'
|
||||
const remote_url_files = [{
|
||||
type: 'image',
|
||||
transfer_method: 'remote_url',
|
||||
url: 'your_url_address'
|
||||
}]
|
||||
|
||||
// Create a completion client
|
||||
const completionClient = new CompletionClient(API_KEY)
|
||||
// Create a completion message
|
||||
completionClient.createCompletionMessage({'query': query}, user)
|
||||
// Create a completion message with vision model
|
||||
completionClient.createCompletionMessage({'query': 'Describe the picture.'}, user, false, remote_url_files)
|
||||
|
||||
// Create a chat client
|
||||
const chatClient = new ChatClient(API_KEY)
|
||||
// Create a chat message in stream mode
|
||||
const response = await chatClient.createChatMessage({}, query, user, true, null)
|
||||
const stream = response.data;
|
||||
stream.on('data', data => {
|
||||
console.log(data);
|
||||
});
|
||||
stream.on('end', () => {
|
||||
console.log('stream done');
|
||||
});
|
||||
// Create a chat message with vision model
|
||||
chatClient.createChatMessage({}, 'Describe the picture.', user, false, null, remote_url_files)
|
||||
// Fetch conversations
|
||||
chatClient.getConversations(user)
|
||||
// Fetch conversation messages
|
||||
chatClient.getConversationMessages(conversationId, user)
|
||||
// Rename conversation
|
||||
chatClient.renameConversation(conversationId, name, user)
|
||||
|
||||
|
||||
const completionClient = new CompletionClient(API_KEY)
|
||||
const workflowClient = new WorkflowClient(API_KEY)
|
||||
const kbClient = new KnowledgeBaseClient(DATASET_API_KEY)
|
||||
const workspaceClient = new WorkspaceClient(DATASET_API_KEY)
|
||||
const client = new DifyClient(API_KEY)
|
||||
// Fetch application parameters
|
||||
client.getApplicationParameters(user)
|
||||
// Provide feedback for a message
|
||||
client.messageFeedback(messageId, rating, user)
|
||||
|
||||
// App core
|
||||
await client.getApplicationParameters(user)
|
||||
await client.messageFeedback('message-id', 'like', user)
|
||||
|
||||
// Completion (blocking)
|
||||
await completionClient.createCompletionMessage({
|
||||
inputs: { query },
|
||||
user,
|
||||
response_mode: 'blocking'
|
||||
})
|
||||
|
||||
// Chat (streaming)
|
||||
const stream = await chatClient.createChatMessage({
|
||||
inputs: {},
|
||||
query,
|
||||
user,
|
||||
response_mode: 'streaming'
|
||||
})
|
||||
for await (const event of stream) {
|
||||
console.log(event.event, event.data)
|
||||
}
|
||||
|
||||
// Chatflow (advanced chat via workflow_id)
|
||||
await chatClient.createChatMessage({
|
||||
inputs: {},
|
||||
query,
|
||||
user,
|
||||
workflow_id: 'workflow-id',
|
||||
response_mode: 'blocking'
|
||||
})
|
||||
|
||||
// Workflow run (blocking or streaming)
|
||||
await workflowClient.run({
|
||||
inputs: { query },
|
||||
user,
|
||||
response_mode: 'blocking'
|
||||
})
|
||||
|
||||
// Knowledge base (dataset token required)
|
||||
await kbClient.listDatasets({ page: 1, limit: 20 })
|
||||
await kbClient.createDataset({ name: 'KB', indexing_technique: 'economy' })
|
||||
|
||||
// RAG pipeline (may require service API route registration)
|
||||
const pipelineStream = await kbClient.runPipeline('dataset-id', {
|
||||
inputs: {},
|
||||
datasource_type: 'online_document',
|
||||
datasource_info_list: [],
|
||||
start_node_id: 'start-node-id',
|
||||
is_published: true,
|
||||
response_mode: 'streaming'
|
||||
})
|
||||
for await (const event of pipelineStream) {
|
||||
console.log(event.data)
|
||||
}
|
||||
|
||||
// Workspace models (dataset token required)
|
||||
await workspaceClient.getModelsByType('text-embedding')
|
||||
|
||||
```
|
||||
|
||||
Replace 'your-api-key-here' with your actual Dify API key. Replace 'your-app-id-here' with your actual Dify APP ID.
|
||||
Notes:
|
||||
|
||||
- App endpoints use an app API token; knowledge base and workspace endpoints use a dataset API token.
|
||||
- Chat/completion require a stable `user` identifier in the request payload.
|
||||
- For streaming responses, iterate the returned AsyncIterable, or use `stream.toText()` to collect the full text (see the sketch below).
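
A minimal sketch of the streaming note above, assuming `stream.toText()` is async and resolves to the concatenated answer text (the exact return shape may differ in the published typings; the query and user values are placeholders):

```js
import { ChatClient } from 'dify-client'

// Placeholder key and user id for illustration only.
const chatClient = new ChatClient('your-app-api-key')

const stream = await chatClient.createChatMessage({
  inputs: {},
  query: 'Summarize Dify in one sentence.',
  user: 'random-user-id',
  response_mode: 'streaming',
})

// Option 1: consume the AsyncIterable event by event.
// for await (const event of stream) console.log(event.event, event.data)

// Option 2: collect everything into plain text, per the note above.
// Assumption: toText() drains the stream and returns the full answer as a string.
const text = await stream.toText()
console.log(text)
```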
|
||||
|
||||
## License
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +0,0 @@
|
|||
module.exports = {
|
||||
presets: [
|
||||
[
|
||||
"@babel/preset-env",
|
||||
{
|
||||
targets: {
|
||||
node: "current",
|
||||
},
|
||||
},
|
||||
],
|
||||
],
|
||||
};
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
import js from "@eslint/js";
|
||||
import tsParser from "@typescript-eslint/parser";
|
||||
import tsPlugin from "@typescript-eslint/eslint-plugin";
|
||||
import { fileURLToPath } from "node:url";
|
||||
import path from "node:path";
|
||||
|
||||
const tsconfigRootDir = path.dirname(fileURLToPath(import.meta.url));
|
||||
const typeCheckedRules =
|
||||
tsPlugin.configs["recommended-type-checked"]?.rules ??
|
||||
tsPlugin.configs.recommendedTypeChecked?.rules ??
|
||||
{};
|
||||
|
||||
export default [
|
||||
{
|
||||
ignores: ["dist", "node_modules", "scripts", "tests", "**/*.test.*", "**/*.spec.*"],
|
||||
},
|
||||
js.configs.recommended,
|
||||
{
|
||||
files: ["src/**/*.ts"],
|
||||
languageOptions: {
|
||||
parser: tsParser,
|
||||
ecmaVersion: "latest",
|
||||
parserOptions: {
|
||||
project: "./tsconfig.json",
|
||||
tsconfigRootDir,
|
||||
sourceType: "module",
|
||||
},
|
||||
},
|
||||
plugins: {
|
||||
"@typescript-eslint": tsPlugin,
|
||||
},
|
||||
rules: {
|
||||
...tsPlugin.configs.recommended.rules,
|
||||
...typeCheckedRules,
|
||||
"no-undef": "off",
|
||||
"no-unused-vars": "off",
|
||||
"@typescript-eslint/no-unsafe-call": "error",
|
||||
"@typescript-eslint/no-unsafe-return": "error",
|
||||
"@typescript-eslint/consistent-type-imports": [
|
||||
"error",
|
||||
{ prefer: "type-imports", fixStyle: "separate-type-imports" },
|
||||
],
|
||||
},
|
||||
},
|
||||
];
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
// Types.d.ts
|
||||
export const BASE_URL: string;
|
||||
|
||||
export type RequestMethods = 'GET' | 'POST' | 'PATCH' | 'DELETE';
|
||||
|
||||
interface Params {
|
||||
[key: string]: any;
|
||||
}
|
||||
|
||||
interface HeaderParams {
|
||||
[key: string]: string;
|
||||
}
|
||||
|
||||
interface User {
|
||||
}
|
||||
|
||||
interface DifyFileBase {
|
||||
type: "image"
|
||||
}
|
||||
|
||||
export interface DifyRemoteFile extends DifyFileBase {
|
||||
transfer_method: "remote_url"
|
||||
url: string
|
||||
}
|
||||
|
||||
export interface DifyLocalFile extends DifyFileBase {
|
||||
transfer_method: "local_file"
|
||||
upload_file_id: string
|
||||
}
|
||||
|
||||
export type DifyFile = DifyRemoteFile | DifyLocalFile;
|
||||
|
||||
export declare class DifyClient {
|
||||
constructor(apiKey: string, baseUrl?: string);
|
||||
|
||||
updateApiKey(apiKey: string): void;
|
||||
|
||||
sendRequest(
|
||||
method: RequestMethods,
|
||||
endpoint: string,
|
||||
data?: any,
|
||||
params?: Params,
|
||||
stream?: boolean,
|
||||
headerParams?: HeaderParams
|
||||
): Promise<any>;
|
||||
|
||||
messageFeedback(message_id: string, rating: number, user: User): Promise<any>;
|
||||
|
||||
getApplicationParameters(user: User): Promise<any>;
|
||||
|
||||
fileUpload(data: FormData): Promise<any>;
|
||||
|
||||
textToAudio(text: string, user: string, streaming?: boolean): Promise<any>;
|
||||
|
||||
getMeta(user: User): Promise<any>;
|
||||
}
|
||||
|
||||
export declare class CompletionClient extends DifyClient {
|
||||
createCompletionMessage(
|
||||
inputs: any,
|
||||
user: User,
|
||||
stream?: boolean,
|
||||
files?: DifyFile[] | null
|
||||
): Promise<any>;
|
||||
}
|
||||
|
||||
export declare class ChatClient extends DifyClient {
|
||||
createChatMessage(
|
||||
inputs: any,
|
||||
query: string,
|
||||
user: User,
|
||||
stream?: boolean,
|
||||
conversation_id?: string | null,
|
||||
files?: DifyFile[] | null
|
||||
): Promise<any>;
|
||||
|
||||
getSuggested(message_id: string, user: User): Promise<any>;
|
||||
|
||||
stopMessage(task_id: string, user: User): Promise<any>;
|
||||
|
||||
|
||||
getConversations(
|
||||
user: User,
|
||||
first_id?: string | null,
|
||||
limit?: number | null,
|
||||
pinned?: boolean | null
|
||||
): Promise<any>;
|
||||
|
||||
getConversationMessages(
|
||||
user: User,
|
||||
conversation_id?: string,
|
||||
first_id?: string | null,
|
||||
limit?: number | null
|
||||
): Promise<any>;
|
||||
|
||||
renameConversation(conversation_id: string, name: string, user: User, auto_generate: boolean): Promise<any>;
|
||||
|
||||
deleteConversation(conversation_id: string, user: User): Promise<any>;
|
||||
|
||||
audioToText(data: FormData): Promise<any>;
|
||||
}
|
||||
|
||||
export declare class WorkflowClient extends DifyClient {
|
||||
run(inputs: any, user: User, stream?: boolean): Promise<any>;
|
||||
|
||||
stop(task_id: string, user: User): Promise<any>;
|
||||
}
|
||||
|
|
@ -1,351 +0,0 @@
|
|||
import axios from "axios";
|
||||
export const BASE_URL = "https://api.dify.ai/v1";
|
||||
|
||||
export const routes = {
|
||||
// app's
|
||||
feedback: {
|
||||
method: "POST",
|
||||
url: (message_id) => `/messages/${message_id}/feedbacks`,
|
||||
},
|
||||
application: {
|
||||
method: "GET",
|
||||
url: () => `/parameters`,
|
||||
},
|
||||
fileUpload: {
|
||||
method: "POST",
|
||||
url: () => `/files/upload`,
|
||||
},
|
||||
textToAudio: {
|
||||
method: "POST",
|
||||
url: () => `/text-to-audio`,
|
||||
},
|
||||
getMeta: {
|
||||
method: "GET",
|
||||
url: () => `/meta`,
|
||||
},
|
||||
|
||||
// completion's
|
||||
createCompletionMessage: {
|
||||
method: "POST",
|
||||
url: () => `/completion-messages`,
|
||||
},
|
||||
|
||||
// chat's
|
||||
createChatMessage: {
|
||||
method: "POST",
|
||||
url: () => `/chat-messages`,
|
||||
},
|
||||
getSuggested:{
|
||||
method: "GET",
|
||||
url: (message_id) => `/messages/${message_id}/suggested`,
|
||||
},
|
||||
stopChatMessage: {
|
||||
method: "POST",
|
||||
url: (task_id) => `/chat-messages/${task_id}/stop`,
|
||||
},
|
||||
getConversations: {
|
||||
method: "GET",
|
||||
url: () => `/conversations`,
|
||||
},
|
||||
getConversationMessages: {
|
||||
method: "GET",
|
||||
url: () => `/messages`,
|
||||
},
|
||||
renameConversation: {
|
||||
method: "POST",
|
||||
url: (conversation_id) => `/conversations/${conversation_id}/name`,
|
||||
},
|
||||
deleteConversation: {
|
||||
method: "DELETE",
|
||||
url: (conversation_id) => `/conversations/${conversation_id}`,
|
||||
},
|
||||
audioToText: {
|
||||
method: "POST",
|
||||
url: () => `/audio-to-text`,
|
||||
},
|
||||
|
||||
// workflow's
|
||||
runWorkflow: {
|
||||
method: "POST",
|
||||
url: () => `/workflows/run`,
|
||||
},
|
||||
stopWorkflow: {
|
||||
method: "POST",
|
||||
url: (task_id) => `/workflows/tasks/${task_id}/stop`,
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
export class DifyClient {
|
||||
constructor(apiKey, baseUrl = BASE_URL) {
|
||||
this.apiKey = apiKey;
|
||||
this.baseUrl = baseUrl;
|
||||
}
|
||||
|
||||
updateApiKey(apiKey) {
|
||||
this.apiKey = apiKey;
|
||||
}
|
||||
|
||||
async sendRequest(
|
||||
method,
|
||||
endpoint,
|
||||
data = null,
|
||||
params = null,
|
||||
stream = false,
|
||||
headerParams = {}
|
||||
) {
|
||||
const isFormData =
|
||||
(typeof FormData !== "undefined" && data instanceof FormData) ||
|
||||
(data && data.constructor && data.constructor.name === "FormData");
|
||||
const headers = {
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
...(isFormData ? {} : { "Content-Type": "application/json" }),
|
||||
...headerParams,
|
||||
};
|
||||
|
||||
const url = `${this.baseUrl}${endpoint}`;
|
||||
let response;
|
||||
if (stream) {
|
||||
response = await axios({
|
||||
method,
|
||||
url,
|
||||
data,
|
||||
params,
|
||||
headers,
|
||||
responseType: "stream",
|
||||
});
|
||||
} else {
|
||||
response = await axios({
|
||||
method,
|
||||
url,
|
||||
...(method !== "GET" && { data }),
|
||||
params,
|
||||
headers,
|
||||
responseType: "json",
|
||||
});
|
||||
}
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
messageFeedback(message_id, rating, user) {
|
||||
const data = {
|
||||
rating,
|
||||
user,
|
||||
};
|
||||
return this.sendRequest(
|
||||
routes.feedback.method,
|
||||
routes.feedback.url(message_id),
|
||||
data
|
||||
);
|
||||
}
|
||||
|
||||
getApplicationParameters(user) {
|
||||
const params = { user };
|
||||
return this.sendRequest(
|
||||
routes.application.method,
|
||||
routes.application.url(),
|
||||
null,
|
||||
params
|
||||
);
|
||||
}
|
||||
|
||||
fileUpload(data) {
|
||||
return this.sendRequest(
|
||||
routes.fileUpload.method,
|
||||
routes.fileUpload.url(),
|
||||
data
|
||||
);
|
||||
}
|
||||
|
||||
textToAudio(text, user, streaming = false) {
|
||||
const data = {
|
||||
text,
|
||||
user,
|
||||
streaming
|
||||
};
|
||||
return this.sendRequest(
|
||||
routes.textToAudio.method,
|
||||
routes.textToAudio.url(),
|
||||
data,
|
||||
null,
|
||||
streaming
|
||||
);
|
||||
}
|
||||
|
||||
getMeta(user) {
|
||||
const params = { user };
|
||||
return this.sendRequest(
|
||||
routes.getMeta.method,
|
||||
routes.getMeta.url(),
|
||||
null,
|
||||
params
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export class CompletionClient extends DifyClient {
|
||||
createCompletionMessage(inputs, user, stream = false, files = null) {
|
||||
const data = {
|
||||
inputs,
|
||||
user,
|
||||
response_mode: stream ? "streaming" : "blocking",
|
||||
files,
|
||||
};
|
||||
return this.sendRequest(
|
||||
routes.createCompletionMessage.method,
|
||||
routes.createCompletionMessage.url(),
|
||||
data,
|
||||
null,
|
||||
stream
|
||||
);
|
||||
}
|
||||
|
||||
runWorkflow(inputs, user, stream = false, files = null) {
|
||||
const data = {
|
||||
inputs,
|
||||
user,
|
||||
response_mode: stream ? "streaming" : "blocking",
|
||||
};
|
||||
return this.sendRequest(
|
||||
routes.runWorkflow.method,
|
||||
routes.runWorkflow.url(),
|
||||
data,
|
||||
null,
|
||||
stream
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export class ChatClient extends DifyClient {
|
||||
createChatMessage(
|
||||
inputs,
|
||||
query,
|
||||
user,
|
||||
stream = false,
|
||||
conversation_id = null,
|
||||
files = null
|
||||
) {
|
||||
const data = {
|
||||
inputs,
|
||||
query,
|
||||
user,
|
||||
response_mode: stream ? "streaming" : "blocking",
|
||||
files,
|
||||
};
|
||||
if (conversation_id) data.conversation_id = conversation_id;
|
||||
|
||||
return this.sendRequest(
|
||||
routes.createChatMessage.method,
|
||||
routes.createChatMessage.url(),
|
||||
data,
|
||||
null,
|
||||
stream
|
||||
);
|
||||
}
|
||||
|
||||
getSuggested(message_id, user) {
|
||||
const data = { user };
|
||||
return this.sendRequest(
|
||||
routes.getSuggested.method,
|
||||
routes.getSuggested.url(message_id),
|
||||
data
|
||||
);
|
||||
}
|
||||
|
||||
stopMessage(task_id, user) {
|
||||
const data = { user };
|
||||
return this.sendRequest(
|
||||
routes.stopChatMessage.method,
|
||||
routes.stopChatMessage.url(task_id),
|
||||
data
|
||||
);
|
||||
}
|
||||
|
||||
getConversations(user, first_id = null, limit = null, pinned = null) {
|
||||
const params = { user, first_id: first_id, limit, pinned };
|
||||
return this.sendRequest(
|
||||
routes.getConversations.method,
|
||||
routes.getConversations.url(),
|
||||
null,
|
||||
params
|
||||
);
|
||||
}
|
||||
|
||||
getConversationMessages(
|
||||
user,
|
||||
conversation_id = "",
|
||||
first_id = null,
|
||||
limit = null
|
||||
) {
|
||||
const params = { user };
|
||||
|
||||
if (conversation_id) params.conversation_id = conversation_id;
|
||||
|
||||
if (first_id) params.first_id = first_id;
|
||||
|
||||
if (limit) params.limit = limit;
|
||||
|
||||
return this.sendRequest(
|
||||
routes.getConversationMessages.method,
|
||||
routes.getConversationMessages.url(),
|
||||
null,
|
||||
params
|
||||
);
|
||||
}
|
||||
|
||||
renameConversation(conversation_id, name, user, auto_generate) {
|
||||
const data = { name, user, auto_generate };
|
||||
return this.sendRequest(
|
||||
routes.renameConversation.method,
|
||||
routes.renameConversation.url(conversation_id),
|
||||
data
|
||||
);
|
||||
}
|
||||
|
||||
deleteConversation(conversation_id, user) {
|
||||
const data = { user };
|
||||
return this.sendRequest(
|
||||
routes.deleteConversation.method,
|
||||
routes.deleteConversation.url(conversation_id),
|
||||
data
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
audioToText(data) {
|
||||
return this.sendRequest(
|
||||
routes.audioToText.method,
|
||||
routes.audioToText.url(),
|
||||
data
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
export class WorkflowClient extends DifyClient {
|
||||
run(inputs, user, stream) {
|
||||
const data = {
|
||||
inputs,
|
||||
response_mode: stream ? "streaming" : "blocking",
|
||||
user
|
||||
};
|
||||
|
||||
return this.sendRequest(
|
||||
routes.runWorkflow.method,
|
||||
routes.runWorkflow.url(),
|
||||
data,
|
||||
null,
|
||||
stream
|
||||
);
|
||||
}
|
||||
|
||||
stop(task_id, user) {
|
||||
const data = { user };
|
||||
return this.sendRequest(
|
||||
routes.stopWorkflow.method,
|
||||
routes.stopWorkflow.url(task_id),
|
||||
data
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,141 +0,0 @@
|
|||
import { DifyClient, WorkflowClient, BASE_URL, routes } from ".";
|
||||
|
||||
import axios from 'axios'
|
||||
|
||||
jest.mock('axios')
|
||||
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks()
|
||||
})
|
||||
|
||||
describe('Client', () => {
|
||||
let difyClient
|
||||
beforeEach(() => {
|
||||
difyClient = new DifyClient('test')
|
||||
})
|
||||
|
||||
test('should create a client', () => {
|
||||
expect(difyClient).toBeDefined();
|
||||
})
|
||||
// test updateApiKey
|
||||
test('should update the api key', () => {
|
||||
difyClient.updateApiKey('test2');
|
||||
expect(difyClient.apiKey).toBe('test2');
|
||||
})
|
||||
});
|
||||
|
||||
describe('Send Requests', () => {
|
||||
let difyClient
|
||||
|
||||
beforeEach(() => {
|
||||
difyClient = new DifyClient('test')
|
||||
})
|
||||
|
||||
it('should make a successful request to the application parameter', async () => {
|
||||
const method = 'GET'
|
||||
const endpoint = routes.application.url()
|
||||
const expectedResponse = { data: 'response' }
|
||||
axios.mockResolvedValue(expectedResponse)
|
||||
|
||||
await difyClient.sendRequest(method, endpoint)
|
||||
|
||||
expect(axios).toHaveBeenCalledWith({
|
||||
method,
|
||||
url: `${BASE_URL}${endpoint}`,
|
||||
params: null,
|
||||
headers: {
|
||||
Authorization: `Bearer ${difyClient.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
responseType: 'json',
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
it('should handle errors from the API', async () => {
|
||||
const method = 'GET'
|
||||
const endpoint = '/test-endpoint'
|
||||
const errorMessage = 'Request failed with status code 404'
|
||||
axios.mockRejectedValue(new Error(errorMessage))
|
||||
|
||||
await expect(difyClient.sendRequest(method, endpoint)).rejects.toThrow(
|
||||
errorMessage
|
||||
)
|
||||
})
|
||||
|
||||
it('uses the getMeta route configuration', async () => {
|
||||
axios.mockResolvedValue({ data: 'ok' })
|
||||
await difyClient.getMeta('end-user')
|
||||
|
||||
expect(axios).toHaveBeenCalledWith({
|
||||
method: routes.getMeta.method,
|
||||
url: `${BASE_URL}${routes.getMeta.url()}`,
|
||||
params: { user: 'end-user' },
|
||||
headers: {
|
||||
Authorization: `Bearer ${difyClient.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
responseType: 'json',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('File uploads', () => {
|
||||
let difyClient
|
||||
const OriginalFormData = global.FormData
|
||||
|
||||
beforeAll(() => {
|
||||
global.FormData = class FormDataMock {}
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
global.FormData = OriginalFormData
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
difyClient = new DifyClient('test')
|
||||
})
|
||||
|
||||
it('does not override multipart boundary headers for FormData', async () => {
|
||||
const form = new FormData()
|
||||
axios.mockResolvedValue({ data: 'ok' })
|
||||
|
||||
await difyClient.fileUpload(form)
|
||||
|
||||
expect(axios).toHaveBeenCalledWith({
|
||||
method: routes.fileUpload.method,
|
||||
url: `${BASE_URL}${routes.fileUpload.url()}`,
|
||||
data: form,
|
||||
params: null,
|
||||
headers: {
|
||||
Authorization: `Bearer ${difyClient.apiKey}`,
|
||||
},
|
||||
responseType: 'json',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Workflow client', () => {
|
||||
let workflowClient
|
||||
|
||||
beforeEach(() => {
|
||||
workflowClient = new WorkflowClient('test')
|
||||
})
|
||||
|
||||
it('uses tasks stop path for workflow stop', async () => {
|
||||
axios.mockResolvedValue({ data: 'stopped' })
|
||||
await workflowClient.stop('task-1', 'end-user')
|
||||
|
||||
expect(axios).toHaveBeenCalledWith({
|
||||
method: routes.stopWorkflow.method,
|
||||
url: `${BASE_URL}${routes.stopWorkflow.url('task-1')}`,
|
||||
data: { user: 'end-user' },
|
||||
params: null,
|
||||
headers: {
|
||||
Authorization: `Bearer ${workflowClient.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
responseType: 'json',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
module.exports = {
|
||||
testEnvironment: "node",
|
||||
transform: {
|
||||
"^.+\\.[tj]sx?$": "babel-jest",
|
||||
},
|
||||
};
|
||||
|
|
@ -1,30 +1,70 @@
|
|||
{
|
||||
"name": "dify-client",
|
||||
"version": "2.3.2",
|
||||
"version": "3.0.0",
|
||||
"description": "This is the Node.js SDK for the Dify.AI API, which allows you to easily integrate Dify.AI into your Node.js applications.",
|
||||
"main": "index.js",
|
||||
"type": "module",
|
||||
"types":"index.d.ts",
|
||||
"main": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts",
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js"
|
||||
}
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
},
|
||||
"files": [
|
||||
"dist",
|
||||
"README.md",
|
||||
"LICENSE"
|
||||
],
|
||||
"keywords": [
|
||||
"Dify",
|
||||
"Dify.AI",
|
||||
"LLM"
|
||||
"LLM",
|
||||
"AI",
|
||||
"SDK",
|
||||
"API"
|
||||
],
|
||||
"author": "Joel",
|
||||
"author": "LangGenius",
|
||||
"contributors": [
|
||||
"<crazywoola> <<427733928@qq.com>> (https://github.com/crazywoola)"
|
||||
"Joel <iamjoel007@gmail.com> (https://github.com/iamjoel)",
|
||||
"lyzno1 <yuanyouhuilyz@gmail.com> (https://github.com/lyzno1)",
|
||||
"crazywoola <427733928@qq.com> (https://github.com/crazywoola)"
|
||||
],
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/langgenius/dify.git",
|
||||
"directory": "sdks/nodejs-client"
|
||||
},
|
||||
"bugs": {
|
||||
"url": "https://github.com/langgenius/dify/issues"
|
||||
},
|
||||
"homepage": "https://dify.ai",
|
||||
"license": "MIT",
|
||||
"scripts": {
|
||||
"test": "jest"
|
||||
"build": "tsup",
|
||||
"lint": "eslint",
|
||||
"lint:fix": "eslint --fix",
|
||||
"type-check": "tsc -p tsconfig.json --noEmit",
|
||||
"test": "vitest run",
|
||||
"test:coverage": "vitest run --coverage",
|
||||
"publish:check": "./scripts/publish.sh --dry-run",
|
||||
"publish:npm": "./scripts/publish.sh"
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.3.5"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/core": "^7.21.8",
|
||||
"@babel/preset-env": "^7.21.5",
|
||||
"babel-jest": "^29.5.0",
|
||||
"jest": "^29.5.0"
|
||||
"@eslint/js": "^9.2.0",
|
||||
"@types/node": "^20.11.30",
|
||||
"@typescript-eslint/eslint-plugin": "^8.50.1",
|
||||
"@typescript-eslint/parser": "^8.50.1",
|
||||
"@vitest/coverage-v8": "1.6.1",
|
||||
"eslint": "^9.2.0",
|
||||
"tsup": "^8.5.1",
|
||||
"typescript": "^5.4.5",
|
||||
"vitest": "^1.5.0"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -0,0 +1,261 @@
|
|||
#!/usr/bin/env bash
|
||||
#
|
||||
# Dify Node.js SDK Publish Script
|
||||
# ================================
|
||||
# A beautiful and reliable script to publish the SDK to npm
|
||||
#
|
||||
# Usage:
|
||||
# ./scripts/publish.sh # Normal publish
|
||||
# ./scripts/publish.sh --dry-run # Test without publishing
|
||||
# ./scripts/publish.sh --skip-tests # Skip tests (not recommended)
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# ============================================================================
|
||||
# Colors and Formatting
|
||||
# ============================================================================
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
MAGENTA='\033[0;35m'
|
||||
CYAN='\033[0;36m'
|
||||
BOLD='\033[1m'
|
||||
DIM='\033[2m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
print_banner() {
|
||||
echo -e "${CYAN}"
|
||||
echo "╔═══════════════════════════════════════════════════════════════╗"
|
||||
echo "║ ║"
|
||||
echo "║ 🚀 Dify Node.js SDK Publish Script 🚀 ║"
|
||||
echo "║ ║"
|
||||
echo "╚═══════════════════════════════════════════════════════════════╝"
|
||||
echo -e "${NC}"
|
||||
}
|
||||
|
||||
info() {
|
||||
echo -e "${BLUE}ℹ ${NC}$1"
|
||||
}
|
||||
|
||||
success() {
|
||||
echo -e "${GREEN}✔ ${NC}$1"
|
||||
}
|
||||
|
||||
warning() {
|
||||
echo -e "${YELLOW}⚠ ${NC}$1"
|
||||
}
|
||||
|
||||
error() {
|
||||
echo -e "${RED}✖ ${NC}$1"
|
||||
}
|
||||
|
||||
step() {
|
||||
echo -e "\n${MAGENTA}▶ ${BOLD}$1${NC}"
|
||||
}
|
||||
|
||||
divider() {
|
||||
echo -e "${DIM}─────────────────────────────────────────────────────────────────${NC}"
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
|
||||
DRY_RUN=false
|
||||
SKIP_TESTS=false
|
||||
|
||||
# Parse arguments
|
||||
for arg in "$@"; do
|
||||
case $arg in
|
||||
--dry-run)
|
||||
DRY_RUN=true
|
||||
;;
|
||||
--skip-tests)
|
||||
SKIP_TESTS=true
|
||||
;;
|
||||
--help|-h)
|
||||
echo "Usage: $0 [options]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " --dry-run Run without actually publishing"
|
||||
echo " --skip-tests Skip running tests (not recommended)"
|
||||
echo " --help, -h Show this help message"
|
||||
exit 0
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# ============================================================================
|
||||
# Main Script
|
||||
# ============================================================================
|
||||
main() {
|
||||
print_banner
|
||||
cd "$PROJECT_DIR"
|
||||
|
||||
# Show mode
|
||||
if [[ "$DRY_RUN" == true ]]; then
|
||||
warning "Running in DRY-RUN mode - no actual publish will occur"
|
||||
divider
|
||||
fi
|
||||
|
||||
# ========================================================================
|
||||
# Step 1: Environment Check
|
||||
# ========================================================================
|
||||
step "Step 1/6: Checking environment..."
|
||||
|
||||
# Check Node.js
|
||||
if ! command -v node &> /dev/null; then
|
||||
error "Node.js is not installed"
|
||||
exit 1
|
||||
fi
|
||||
NODE_VERSION=$(node -v)
|
||||
success "Node.js: $NODE_VERSION"
|
||||
|
||||
# Check npm
|
||||
if ! command -v npm &> /dev/null; then
|
||||
error "npm is not installed"
|
||||
exit 1
|
||||
fi
|
||||
NPM_VERSION=$(npm -v)
|
||||
success "npm: v$NPM_VERSION"
|
||||
|
||||
# Check pnpm (optional, for local dev)
|
||||
if command -v pnpm &> /dev/null; then
|
||||
PNPM_VERSION=$(pnpm -v)
|
||||
success "pnpm: v$PNPM_VERSION"
|
||||
else
|
||||
info "pnpm not found (optional)"
|
||||
fi
|
||||
|
||||
# Check npm login status
|
||||
if ! npm whoami &> /dev/null; then
|
||||
error "Not logged in to npm. Run 'npm login' first."
|
||||
exit 1
|
||||
fi
|
||||
NPM_USER=$(npm whoami)
|
||||
success "Logged in as: ${BOLD}$NPM_USER${NC}"
|
||||
|
||||
# ========================================================================
|
||||
# Step 2: Read Package Info
|
||||
# ========================================================================
|
||||
step "Step 2/6: Reading package info..."
|
||||
|
||||
PACKAGE_NAME=$(node -p "require('./package.json').name")
|
||||
PACKAGE_VERSION=$(node -p "require('./package.json').version")
|
||||
|
||||
success "Package: ${BOLD}$PACKAGE_NAME${NC}"
|
||||
success "Version: ${BOLD}$PACKAGE_VERSION${NC}"
|
||||
|
||||
# Check if version already exists on npm
|
||||
if npm view "$PACKAGE_NAME@$PACKAGE_VERSION" version &> /dev/null; then
|
||||
error "Version $PACKAGE_VERSION already exists on npm!"
|
||||
echo ""
|
||||
info "Current published versions:"
|
||||
npm view "$PACKAGE_NAME" versions --json 2>/dev/null | tail -5
|
||||
echo ""
|
||||
warning "Please update the version in package.json before publishing."
|
||||
exit 1
|
||||
fi
|
||||
success "Version $PACKAGE_VERSION is available"
|
||||
|
||||
# ========================================================================
|
||||
# Step 3: Install Dependencies
|
||||
# ========================================================================
|
||||
step "Step 3/6: Installing dependencies..."
|
||||
|
||||
if command -v pnpm &> /dev/null; then
|
||||
pnpm install --frozen-lockfile 2>/dev/null || pnpm install
|
||||
else
|
||||
npm ci 2>/dev/null || npm install
|
||||
fi
|
||||
success "Dependencies installed"
|
||||
|
||||
# ========================================================================
|
||||
# Step 4: Run Tests
|
||||
# ========================================================================
|
||||
step "Step 4/6: Running tests..."
|
||||
|
||||
if [[ "$SKIP_TESTS" == true ]]; then
|
||||
warning "Skipping tests (--skip-tests flag)"
|
||||
else
|
||||
if command -v pnpm &> /dev/null; then
|
||||
pnpm test
|
||||
else
|
||||
npm test
|
||||
fi
|
||||
success "All tests passed"
|
||||
fi
|
||||
|
||||
# ========================================================================
|
||||
# Step 5: Build
|
||||
# ========================================================================
|
||||
step "Step 5/6: Building package..."
|
||||
|
||||
# Clean previous build
|
||||
rm -rf dist
|
||||
|
||||
if command -v pnpm &> /dev/null; then
|
||||
pnpm run build
|
||||
else
|
||||
npm run build
|
||||
fi
|
||||
success "Build completed"
|
||||
|
||||
# Verify build output
|
||||
if [[ ! -f "dist/index.js" ]]; then
|
||||
error "Build failed - dist/index.js not found"
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -f "dist/index.d.ts" ]]; then
|
||||
error "Build failed - dist/index.d.ts not found"
|
||||
exit 1
|
||||
fi
|
||||
success "Build output verified"
|
||||
|
||||
# ========================================================================
|
||||
# Step 6: Publish
|
||||
# ========================================================================
|
||||
step "Step 6/6: Publishing to npm..."
|
||||
|
||||
divider
|
||||
echo -e "${CYAN}Package contents:${NC}"
|
||||
npm pack --dry-run 2>&1 | head -30
|
||||
divider
|
||||
|
||||
if [[ "$DRY_RUN" == true ]]; then
|
||||
warning "DRY-RUN: Skipping actual publish"
|
||||
echo ""
|
||||
info "To publish for real, run without --dry-run flag"
|
||||
else
|
||||
echo ""
|
||||
echo -e "${YELLOW}About to publish ${BOLD}$PACKAGE_NAME@$PACKAGE_VERSION${NC}${YELLOW} to npm${NC}"
|
||||
echo -e "${DIM}Press Enter to continue, or Ctrl+C to cancel...${NC}"
|
||||
read -r
|
||||
|
||||
npm publish --access public
|
||||
|
||||
echo ""
|
||||
success "🎉 Successfully published ${BOLD}$PACKAGE_NAME@$PACKAGE_VERSION${NC} to npm!"
|
||||
echo ""
|
||||
echo -e "${GREEN}Install with:${NC}"
|
||||
echo -e " ${CYAN}npm install $PACKAGE_NAME${NC}"
|
||||
echo -e " ${CYAN}pnpm add $PACKAGE_NAME${NC}"
|
||||
echo -e " ${CYAN}yarn add $PACKAGE_NAME${NC}"
|
||||
echo ""
|
||||
echo -e "${GREEN}View on npm:${NC}"
|
||||
echo -e " ${CYAN}https://www.npmjs.com/package/$PACKAGE_NAME${NC}"
|
||||
fi
|
||||
|
||||
divider
|
||||
echo -e "${GREEN}${BOLD}✨ All done!${NC}"
|
||||
}
|
||||
|
||||
# Run main function
|
||||
main "$@"
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { DifyClient } from "./base";
|
||||
import { ValidationError } from "../errors/dify-error";
|
||||
import { createHttpClientWithSpies } from "../../tests/test-utils";
|
||||
|
||||
describe("DifyClient base", () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("getRoot calls root endpoint", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const dify = new DifyClient(client);
|
||||
|
||||
await dify.getRoot();
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/",
|
||||
});
|
||||
});
|
||||
|
||||
it("getApplicationParameters includes optional user", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const dify = new DifyClient(client);
|
||||
|
||||
await dify.getApplicationParameters();
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/parameters",
|
||||
query: undefined,
|
||||
});
|
||||
|
||||
await dify.getApplicationParameters("user-1");
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/parameters",
|
||||
query: { user: "user-1" },
|
||||
});
|
||||
});
|
||||
|
||||
it("getMeta includes optional user", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const dify = new DifyClient(client);
|
||||
|
||||
await dify.getMeta("user-1");
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/meta",
|
||||
query: { user: "user-1" },
|
||||
});
|
||||
});
|
||||
|
||||
it("getInfo and getSite support optional user", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const dify = new DifyClient(client);
|
||||
|
||||
await dify.getInfo();
|
||||
await dify.getSite("user");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/info",
|
||||
query: undefined,
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/site",
|
||||
query: { user: "user" },
|
||||
});
|
||||
});
|
||||
|
||||
it("messageFeedback builds payload from request object", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const dify = new DifyClient(client);
|
||||
|
||||
await dify.messageFeedback({
|
||||
messageId: "msg",
|
||||
user: "user",
|
||||
rating: "like",
|
||||
content: "good",
|
||||
});
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/messages/msg/feedbacks",
|
||||
data: { user: "user", rating: "like", content: "good" },
|
||||
});
|
||||
});
|
||||
|
||||
it("fileUpload appends user to form data", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const dify = new DifyClient(client);
|
||||
const form = { append: vi.fn(), getHeaders: () => ({}) };
|
||||
|
||||
await dify.fileUpload(form, "user");
|
||||
|
||||
expect(form.append).toHaveBeenCalledWith("user", "user");
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/files/upload",
|
||||
data: form,
|
||||
});
|
||||
});
|
||||
|
||||
it("filePreview uses arraybuffer response", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const dify = new DifyClient(client);
|
||||
|
||||
await dify.filePreview("file", "user", true);
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/files/file/preview",
|
||||
query: { user: "user", as_attachment: "true" },
|
||||
responseType: "arraybuffer",
|
||||
});
|
||||
});
|
||||
|
||||
it("audioToText appends user and sends form", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const dify = new DifyClient(client);
|
||||
const form = { append: vi.fn(), getHeaders: () => ({}) };
|
||||
|
||||
await dify.audioToText(form, "user");
|
||||
|
||||
expect(form.append).toHaveBeenCalledWith("user", "user");
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/audio-to-text",
|
||||
data: form,
|
||||
});
|
||||
});
|
||||
|
||||
it("textToAudio supports streaming and message id", async () => {
|
||||
const { client, request, requestBinaryStream } = createHttpClientWithSpies();
|
||||
const dify = new DifyClient(client);
|
||||
|
||||
await dify.textToAudio({
|
||||
user: "user",
|
||||
message_id: "msg",
|
||||
streaming: true,
|
||||
});
|
||||
|
||||
expect(requestBinaryStream).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/text-to-audio",
|
||||
data: {
|
||||
user: "user",
|
||||
message_id: "msg",
|
||||
streaming: true,
|
||||
},
|
||||
});
|
||||
|
||||
await dify.textToAudio("hello", "user", false, "voice");
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/text-to-audio",
|
||||
data: {
|
||||
text: "hello",
|
||||
user: "user",
|
||||
streaming: false,
|
||||
voice: "voice",
|
||||
},
|
||||
responseType: "arraybuffer",
|
||||
});
|
||||
});
|
||||
|
||||
it("textToAudio requires text or message id", async () => {
|
||||
const { client } = createHttpClientWithSpies();
|
||||
const dify = new DifyClient(client);
|
||||
|
||||
expect(() => dify.textToAudio({ user: "user" })).toThrow(ValidationError);
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,284 @@
|
|||
import type {
|
||||
BinaryStream,
|
||||
DifyClientConfig,
|
||||
DifyResponse,
|
||||
MessageFeedbackRequest,
|
||||
QueryParams,
|
||||
RequestMethod,
|
||||
TextToAudioRequest,
|
||||
} from "../types/common";
|
||||
import { HttpClient } from "../http/client";
|
||||
import { ensureNonEmptyString, ensureRating } from "./validation";
|
||||
import { FileUploadError, ValidationError } from "../errors/dify-error";
|
||||
import { isFormData } from "../http/form-data";
|
||||
|
||||
const toConfig = (
|
||||
init: string | DifyClientConfig,
|
||||
baseUrl?: string
|
||||
): DifyClientConfig => {
|
||||
if (typeof init === "string") {
|
||||
return {
|
||||
apiKey: init,
|
||||
baseUrl,
|
||||
};
|
||||
}
|
||||
return init;
|
||||
};
|
||||
|
||||
const appendUserToFormData = (form: unknown, user: string): void => {
|
||||
if (!isFormData(form)) {
|
||||
throw new FileUploadError("FormData is required for file uploads");
|
||||
}
|
||||
if (typeof form.append === "function") {
|
||||
form.append("user", user);
|
||||
}
|
||||
};
|
||||
|
||||
export class DifyClient {
|
||||
protected http: HttpClient;
|
||||
|
||||
constructor(config: string | DifyClientConfig | HttpClient, baseUrl?: string) {
|
||||
if (config instanceof HttpClient) {
|
||||
this.http = config;
|
||||
} else {
|
||||
this.http = new HttpClient(toConfig(config, baseUrl));
|
||||
}
|
||||
}
|
||||
|
||||
updateApiKey(apiKey: string): void {
|
||||
ensureNonEmptyString(apiKey, "apiKey");
|
||||
this.http.updateApiKey(apiKey);
|
||||
}
|
||||
|
||||
getHttpClient(): HttpClient {
|
||||
return this.http;
|
||||
}
|
||||
|
||||
sendRequest(
|
||||
method: RequestMethod,
|
||||
endpoint: string,
|
||||
data: unknown = null,
|
||||
params: QueryParams | null = null,
|
||||
stream = false,
|
||||
headerParams: Record<string, string> = {}
|
||||
): ReturnType<HttpClient["requestRaw"]> {
|
||||
return this.http.requestRaw({
|
||||
method,
|
||||
path: endpoint,
|
||||
data,
|
||||
query: params ?? undefined,
|
||||
headers: headerParams,
|
||||
responseType: stream ? "stream" : "json",
|
||||
});
|
||||
}
|
||||
|
||||
getRoot(): Promise<DifyResponse<unknown>> {
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/",
|
||||
});
|
||||
}
|
||||
|
||||
getApplicationParameters(user?: string): Promise<DifyResponse<unknown>> {
|
||||
if (user) {
|
||||
ensureNonEmptyString(user, "user");
|
||||
}
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/parameters",
|
||||
query: user ? { user } : undefined,
|
||||
});
|
||||
}
|
||||
|
||||
async getParameters(user?: string): Promise<DifyResponse<unknown>> {
|
||||
return this.getApplicationParameters(user);
|
||||
}
|
||||
|
||||
getMeta(user?: string): Promise<DifyResponse<unknown>> {
|
||||
if (user) {
|
||||
ensureNonEmptyString(user, "user");
|
||||
}
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/meta",
|
||||
query: user ? { user } : undefined,
|
||||
});
|
||||
}
|
||||
|
||||
messageFeedback(
|
||||
request: MessageFeedbackRequest
|
||||
): Promise<DifyResponse<Record<string, unknown>>>;
|
||||
messageFeedback(
|
||||
messageId: string,
|
||||
rating: "like" | "dislike" | null,
|
||||
user: string,
|
||||
content?: string
|
||||
): Promise<DifyResponse<Record<string, unknown>>>;
|
||||
messageFeedback(
|
||||
messageIdOrRequest: string | MessageFeedbackRequest,
|
||||
rating?: "like" | "dislike" | null,
|
||||
user?: string,
|
||||
content?: string
|
||||
): Promise<DifyResponse<Record<string, unknown>>> {
|
||||
let messageId: string;
|
||||
const payload: Record<string, unknown> = {};
|
||||
|
||||
if (typeof messageIdOrRequest === "string") {
|
||||
messageId = messageIdOrRequest;
|
||||
ensureNonEmptyString(messageId, "messageId");
|
||||
ensureNonEmptyString(user, "user");
|
||||
payload.user = user;
|
||||
if (rating !== undefined && rating !== null) {
|
||||
ensureRating(rating);
|
||||
payload.rating = rating;
|
||||
}
|
||||
if (content !== undefined) {
|
||||
payload.content = content;
|
||||
}
|
||||
} else {
|
||||
const request = messageIdOrRequest;
|
||||
messageId = request.messageId;
|
||||
ensureNonEmptyString(messageId, "messageId");
|
||||
ensureNonEmptyString(request.user, "user");
|
||||
payload.user = request.user;
|
||||
if (request.rating !== undefined && request.rating !== null) {
|
||||
ensureRating(request.rating);
|
||||
payload.rating = request.rating;
|
||||
}
|
||||
if (request.content !== undefined) {
|
||||
payload.content = request.content;
|
||||
}
|
||||
}
|
||||
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/messages/${messageId}/feedbacks`,
|
||||
data: payload,
|
||||
});
|
||||
}
|
||||
|
||||
getInfo(user?: string): Promise<DifyResponse<unknown>> {
|
||||
if (user) {
|
||||
ensureNonEmptyString(user, "user");
|
||||
}
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/info",
|
||||
query: user ? { user } : undefined,
|
||||
});
|
||||
}
|
||||
|
||||
getSite(user?: string): Promise<DifyResponse<unknown>> {
|
||||
if (user) {
|
||||
ensureNonEmptyString(user, "user");
|
||||
}
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/site",
|
||||
query: user ? { user } : undefined,
|
||||
});
|
||||
}
|
||||
|
||||
fileUpload(form: unknown, user: string): Promise<DifyResponse<unknown>> {
|
||||
if (!isFormData(form)) {
|
||||
throw new FileUploadError("FormData is required for file uploads");
|
||||
}
|
||||
ensureNonEmptyString(user, "user");
|
||||
appendUserToFormData(form, user);
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: "/files/upload",
|
||||
data: form,
|
||||
});
|
||||
}
|
||||
|
||||
filePreview(
|
||||
fileId: string,
|
||||
user: string,
|
||||
asAttachment?: boolean
|
||||
): Promise<DifyResponse<Buffer>> {
|
||||
ensureNonEmptyString(fileId, "fileId");
|
||||
ensureNonEmptyString(user, "user");
|
||||
return this.http.request<Buffer>({
|
||||
method: "GET",
|
||||
path: `/files/${fileId}/preview`,
|
||||
query: {
|
||||
user,
|
||||
as_attachment: asAttachment ? "true" : undefined,
|
||||
},
|
||||
responseType: "arraybuffer",
|
||||
});
|
||||
}
|
||||
|
||||
audioToText(form: unknown, user: string): Promise<DifyResponse<unknown>> {
|
||||
if (!isFormData(form)) {
|
||||
throw new FileUploadError("FormData is required for audio uploads");
|
||||
}
|
||||
ensureNonEmptyString(user, "user");
|
||||
appendUserToFormData(form, user);
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: "/audio-to-text",
|
||||
data: form,
|
||||
});
|
||||
}
|
||||
|
||||
textToAudio(
|
||||
request: TextToAudioRequest
|
||||
): Promise<DifyResponse<Buffer> | BinaryStream>;
|
||||
textToAudio(
|
||||
text: string,
|
||||
user: string,
|
||||
streaming?: boolean,
|
||||
voice?: string
|
||||
): Promise<DifyResponse<Buffer> | BinaryStream>;
|
||||
textToAudio(
|
||||
textOrRequest: string | TextToAudioRequest,
|
||||
user?: string,
|
||||
streaming = false,
|
||||
voice?: string
|
||||
): Promise<DifyResponse<Buffer> | BinaryStream> {
|
||||
let payload: TextToAudioRequest;
|
||||
|
||||
if (typeof textOrRequest === "string") {
|
||||
ensureNonEmptyString(textOrRequest, "text");
|
||||
ensureNonEmptyString(user, "user");
|
||||
payload = {
|
||||
text: textOrRequest,
|
||||
user,
|
||||
streaming,
|
||||
};
|
||||
if (voice) {
|
||||
payload.voice = voice;
|
||||
}
|
||||
} else {
|
||||
payload = { ...textOrRequest };
|
||||
ensureNonEmptyString(payload.user, "user");
|
||||
if (payload.text !== undefined && payload.text !== null) {
|
||||
ensureNonEmptyString(payload.text, "text");
|
||||
}
|
||||
if (payload.message_id !== undefined && payload.message_id !== null) {
|
||||
ensureNonEmptyString(payload.message_id, "messageId");
|
||||
}
|
||||
if (!payload.text && !payload.message_id) {
|
||||
throw new ValidationError("text or message_id is required");
|
||||
}
|
||||
payload.streaming = payload.streaming ?? false;
|
||||
}
|
||||
|
||||
if (payload.streaming) {
|
||||
return this.http.requestBinaryStream({
|
||||
method: "POST",
|
||||
path: "/text-to-audio",
|
||||
data: payload,
|
||||
});
|
||||
}
|
||||
|
||||
return this.http.request<Buffer>({
|
||||
method: "POST",
|
||||
path: "/text-to-audio",
|
||||
data: payload,
|
||||
responseType: "arraybuffer",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { ChatClient } from "./chat";
|
||||
import { ValidationError } from "../errors/dify-error";
|
||||
import { createHttpClientWithSpies } from "../../tests/test-utils";
|
||||
|
||||
describe("ChatClient", () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("creates chat messages in blocking mode", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
await chat.createChatMessage({ input: "x" }, "hello", "user", false, null);
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/chat-messages",
|
||||
data: {
|
||||
inputs: { input: "x" },
|
||||
query: "hello",
|
||||
user: "user",
|
||||
response_mode: "blocking",
|
||||
files: undefined,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("creates chat messages in streaming mode", async () => {
|
||||
const { client, requestStream } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
await chat.createChatMessage({
|
||||
inputs: { input: "x" },
|
||||
query: "hello",
|
||||
user: "user",
|
||||
response_mode: "streaming",
|
||||
});
|
||||
|
||||
expect(requestStream).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/chat-messages",
|
||||
data: {
|
||||
inputs: { input: "x" },
|
||||
query: "hello",
|
||||
user: "user",
|
||||
response_mode: "streaming",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("stops chat messages", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
await chat.stopChatMessage("task", "user");
|
||||
await chat.stopMessage("task", "user");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/chat-messages/task/stop",
|
||||
data: { user: "user" },
|
||||
});
|
||||
});
|
||||
|
||||
it("gets suggested questions", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
await chat.getSuggested("msg", "user");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/messages/msg/suggested",
|
||||
query: { user: "user" },
|
||||
});
|
||||
});
|
||||
|
||||
it("submits message feedback", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
await chat.messageFeedback("msg", "like", "user", "good");
|
||||
await chat.messageFeedback({
|
||||
messageId: "msg",
|
||||
user: "user",
|
||||
rating: "dislike",
|
||||
});
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/messages/msg/feedbacks",
|
||||
data: { user: "user", rating: "like", content: "good" },
|
||||
});
|
||||
});
|
||||
|
||||
it("lists app feedbacks", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
await chat.getAppFeedbacks(2, 5);
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/app/feedbacks",
|
||||
query: { page: 2, limit: 5 },
|
||||
});
|
||||
});
|
||||
|
||||
it("lists conversations and messages", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
await chat.getConversations("user", "last", 10, "-updated_at");
|
||||
await chat.getConversationMessages("user", "conv", "first", 5);
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/conversations",
|
||||
query: {
|
||||
user: "user",
|
||||
last_id: "last",
|
||||
limit: 10,
|
||||
sort_by: "-updated_at",
|
||||
},
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/messages",
|
||||
query: {
|
||||
user: "user",
|
||||
conversation_id: "conv",
|
||||
first_id: "first",
|
||||
limit: 5,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("renames conversations with optional auto-generate", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
await chat.renameConversation("conv", "name", "user", false);
|
||||
await chat.renameConversation("conv", "user", { autoGenerate: true });
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/conversations/conv/name",
|
||||
data: { user: "user", auto_generate: false, name: "name" },
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/conversations/conv/name",
|
||||
data: { user: "user", auto_generate: true },
|
||||
});
|
||||
});
|
||||
|
||||
it("requires name when autoGenerate is false", async () => {
|
||||
const { client } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
expect(() =>
|
||||
chat.renameConversation("conv", "", "user", false)
|
||||
).toThrow(ValidationError);
|
||||
});
|
||||
|
||||
it("deletes conversations", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
await chat.deleteConversation("conv", "user");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "DELETE",
|
||||
path: "/conversations/conv",
|
||||
data: { user: "user" },
|
||||
});
|
||||
});
|
||||
|
||||
it("manages conversation variables", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
await chat.getConversationVariables("conv", "user", "last", 10, "name");
|
||||
await chat.updateConversationVariable("conv", "var", "user", "value");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/conversations/conv/variables",
|
||||
query: {
|
||||
user: "user",
|
||||
last_id: "last",
|
||||
limit: 10,
|
||||
variable_name: "name",
|
||||
},
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "PUT",
|
||||
path: "/conversations/conv/variables/var",
|
||||
data: { user: "user", value: "value" },
|
||||
});
|
||||
});
|
||||
|
||||
it("handles annotation APIs", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const chat = new ChatClient(client);
|
||||
|
||||
await chat.annotationReplyAction("enable", {
|
||||
score_threshold: 0.5,
|
||||
embedding_provider_name: "prov",
|
||||
embedding_model_name: "model",
|
||||
});
|
||||
await chat.getAnnotationReplyStatus("enable", "job");
|
||||
await chat.listAnnotations({ page: 1, limit: 10, keyword: "k" });
|
||||
await chat.createAnnotation({ question: "q", answer: "a" });
|
||||
await chat.updateAnnotation("id", { question: "q", answer: "a" });
|
||||
await chat.deleteAnnotation("id");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/apps/annotation-reply/enable",
|
||||
data: {
|
||||
score_threshold: 0.5,
|
||||
embedding_provider_name: "prov",
|
||||
embedding_model_name: "model",
|
||||
},
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/apps/annotation-reply/enable/status/job",
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/apps/annotations",
|
||||
query: { page: 1, limit: 10, keyword: "k" },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,377 @@
|
|||
import { DifyClient } from "./base";
|
||||
import type { ChatMessageRequest, ChatMessageResponse } from "../types/chat";
|
||||
import type {
|
||||
AnnotationCreateRequest,
|
||||
AnnotationListOptions,
|
||||
AnnotationReplyActionRequest,
|
||||
AnnotationResponse,
|
||||
} from "../types/annotation";
|
||||
import type {
|
||||
DifyResponse,
|
||||
DifyStream,
|
||||
QueryParams,
|
||||
} from "../types/common";
|
||||
import {
|
||||
ensureNonEmptyString,
|
||||
ensureOptionalInt,
|
||||
ensureOptionalString,
|
||||
} from "./validation";
|
||||
|
||||
export class ChatClient extends DifyClient {
|
||||
createChatMessage(
|
||||
request: ChatMessageRequest
|
||||
): Promise<DifyResponse<ChatMessageResponse> | DifyStream<ChatMessageResponse>>;
|
||||
createChatMessage(
|
||||
inputs: Record<string, unknown>,
|
||||
query: string,
|
||||
user: string,
|
||||
stream?: boolean,
|
||||
conversationId?: string | null,
|
||||
files?: Array<Record<string, unknown>> | null
|
||||
): Promise<DifyResponse<ChatMessageResponse> | DifyStream<ChatMessageResponse>>;
|
||||
createChatMessage(
|
||||
inputOrRequest: ChatMessageRequest | Record<string, unknown>,
|
||||
query?: string,
|
||||
user?: string,
|
||||
stream = false,
|
||||
conversationId?: string | null,
|
||||
files?: Array<Record<string, unknown>> | null
|
||||
): Promise<DifyResponse<ChatMessageResponse> | DifyStream<ChatMessageResponse>> {
|
||||
let payload: ChatMessageRequest;
|
||||
let shouldStream = stream;
|
||||
|
||||
if (query === undefined && "user" in (inputOrRequest as ChatMessageRequest)) {
|
||||
payload = inputOrRequest as ChatMessageRequest;
|
||||
shouldStream = payload.response_mode === "streaming";
|
||||
} else {
|
||||
ensureNonEmptyString(query, "query");
|
||||
ensureNonEmptyString(user, "user");
|
||||
payload = {
|
||||
inputs: inputOrRequest as Record<string, unknown>,
|
||||
query,
|
||||
user,
|
||||
response_mode: stream ? "streaming" : "blocking",
|
||||
files,
|
||||
};
|
||||
if (conversationId) {
|
||||
payload.conversation_id = conversationId;
|
||||
}
|
||||
}
|
||||
|
||||
ensureNonEmptyString(payload.user, "user");
|
||||
ensureNonEmptyString(payload.query, "query");
|
||||
|
||||
if (shouldStream) {
|
||||
return this.http.requestStream<ChatMessageResponse>({
|
||||
method: "POST",
|
||||
path: "/chat-messages",
|
||||
data: payload,
|
||||
});
|
||||
}
|
||||
|
||||
return this.http.request<ChatMessageResponse>({
|
||||
method: "POST",
|
||||
path: "/chat-messages",
|
||||
data: payload,
|
||||
});
|
||||
}
|
||||
|
||||
stopChatMessage(
|
||||
taskId: string,
|
||||
user: string
|
||||
): Promise<DifyResponse<ChatMessageResponse>> {
|
||||
ensureNonEmptyString(taskId, "taskId");
|
||||
ensureNonEmptyString(user, "user");
|
||||
return this.http.request<ChatMessageResponse>({
|
||||
method: "POST",
|
||||
path: `/chat-messages/${taskId}/stop`,
|
||||
data: { user },
|
||||
});
|
||||
}
|
||||
|
||||
stopMessage(
|
||||
taskId: string,
|
||||
user: string
|
||||
): Promise<DifyResponse<ChatMessageResponse>> {
|
||||
return this.stopChatMessage(taskId, user);
|
||||
}
|
||||
|
||||
getSuggested(
|
||||
messageId: string,
|
||||
user: string
|
||||
): Promise<DifyResponse<ChatMessageResponse>> {
|
||||
ensureNonEmptyString(messageId, "messageId");
|
||||
ensureNonEmptyString(user, "user");
|
||||
return this.http.request<ChatMessageResponse>({
|
||||
method: "GET",
|
||||
path: `/messages/${messageId}/suggested`,
|
||||
query: { user },
|
||||
});
|
||||
}
|
||||
|
||||
// Note: messageFeedback is inherited from DifyClient
|
||||
|
||||
getAppFeedbacks(
|
||||
page?: number,
|
||||
limit?: number
|
||||
): Promise<DifyResponse<Record<string, unknown>>> {
|
||||
ensureOptionalInt(page, "page");
|
||||
ensureOptionalInt(limit, "limit");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/app/feedbacks",
|
||||
query: {
|
||||
page,
|
||||
limit,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
getConversations(
|
||||
user: string,
|
||||
lastId?: string | null,
|
||||
limit?: number | null,
|
||||
sortByOrPinned?: string | boolean | null
|
||||
): Promise<DifyResponse<Record<string, unknown>>> {
|
||||
ensureNonEmptyString(user, "user");
|
||||
ensureOptionalString(lastId, "lastId");
|
||||
ensureOptionalInt(limit, "limit");
|
||||
|
||||
const params: QueryParams = { user };
|
||||
if (lastId) {
|
||||
params.last_id = lastId;
|
||||
}
|
||||
if (limit) {
|
||||
params.limit = limit;
|
||||
}
|
||||
if (typeof sortByOrPinned === "string") {
|
||||
params.sort_by = sortByOrPinned;
|
||||
} else if (typeof sortByOrPinned === "boolean") {
|
||||
params.pinned = sortByOrPinned;
|
||||
}
|
||||
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/conversations",
|
||||
query: params,
|
||||
});
|
||||
}
|
||||
|
||||
getConversationMessages(
|
||||
user: string,
|
||||
conversationId: string,
|
||||
firstId?: string | null,
|
||||
limit?: number | null
|
||||
): Promise<DifyResponse<Record<string, unknown>>> {
|
||||
ensureNonEmptyString(user, "user");
|
||||
ensureNonEmptyString(conversationId, "conversationId");
|
||||
ensureOptionalString(firstId, "firstId");
|
||||
ensureOptionalInt(limit, "limit");
|
||||
|
||||
const params: QueryParams = { user };
|
||||
params.conversation_id = conversationId;
|
||||
if (firstId) {
|
||||
params.first_id = firstId;
|
||||
}
|
||||
if (limit) {
|
||||
params.limit = limit;
|
||||
}
|
||||
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/messages",
|
||||
query: params,
|
||||
});
|
||||
}
|
||||
|
||||
renameConversation(
|
||||
conversationId: string,
|
||||
name: string,
|
||||
user: string,
|
||||
autoGenerate?: boolean
|
||||
): Promise<DifyResponse<Record<string, unknown>>>;
|
||||
renameConversation(
|
||||
conversationId: string,
|
||||
user: string,
|
||||
options?: { name?: string | null; autoGenerate?: boolean }
|
||||
): Promise<DifyResponse<Record<string, unknown>>>;
|
||||
renameConversation(
|
||||
conversationId: string,
|
||||
nameOrUser: string,
|
||||
userOrOptions?: string | { name?: string | null; autoGenerate?: boolean },
|
||||
autoGenerate?: boolean
|
||||
): Promise<DifyResponse<Record<string, unknown>>> {
|
||||
ensureNonEmptyString(conversationId, "conversationId");
|
||||
|
||||
let name: string | null | undefined;
|
||||
let user: string;
|
||||
let resolvedAutoGenerate: boolean;
|
||||
|
||||
if (typeof userOrOptions === "string" || userOrOptions === undefined) {
|
||||
name = nameOrUser;
|
||||
user = userOrOptions ?? "";
|
||||
resolvedAutoGenerate = autoGenerate ?? false;
|
||||
} else {
|
||||
user = nameOrUser;
|
||||
name = userOrOptions.name;
|
||||
resolvedAutoGenerate = userOrOptions.autoGenerate ?? false;
|
||||
}
|
||||
|
||||
ensureNonEmptyString(user, "user");
|
||||
if (!resolvedAutoGenerate) {
|
||||
ensureNonEmptyString(name, "name");
|
||||
}
|
||||
|
||||
const payload: Record<string, unknown> = {
|
||||
user,
|
||||
auto_generate: resolvedAutoGenerate,
|
||||
};
|
||||
if (typeof name === "string" && name.trim().length > 0) {
|
||||
payload.name = name;
|
||||
}
|
||||
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/conversations/${conversationId}/name`,
|
||||
data: payload,
|
||||
});
|
||||
}
|
||||
|
||||
deleteConversation(
|
||||
conversationId: string,
|
||||
user: string
|
||||
): Promise<DifyResponse<Record<string, unknown>>> {
|
||||
ensureNonEmptyString(conversationId, "conversationId");
|
||||
ensureNonEmptyString(user, "user");
|
||||
return this.http.request({
|
||||
method: "DELETE",
|
||||
path: `/conversations/${conversationId}`,
|
||||
data: { user },
|
||||
});
|
||||
}
|
||||
|
||||
getConversationVariables(
|
||||
conversationId: string,
|
||||
user: string,
|
||||
lastId?: string | null,
|
||||
limit?: number | null,
|
||||
variableName?: string | null
|
||||
): Promise<DifyResponse<Record<string, unknown>>> {
|
||||
ensureNonEmptyString(conversationId, "conversationId");
|
||||
ensureNonEmptyString(user, "user");
|
||||
ensureOptionalString(lastId, "lastId");
|
||||
ensureOptionalInt(limit, "limit");
|
||||
ensureOptionalString(variableName, "variableName");
|
||||
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/conversations/${conversationId}/variables`,
|
||||
query: {
|
||||
user,
|
||||
last_id: lastId ?? undefined,
|
||||
limit: limit ?? undefined,
|
||||
variable_name: variableName ?? undefined,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
updateConversationVariable(
|
||||
conversationId: string,
|
||||
variableId: string,
|
||||
user: string,
|
||||
value: unknown
|
||||
): Promise<DifyResponse<Record<string, unknown>>> {
|
||||
ensureNonEmptyString(conversationId, "conversationId");
|
||||
ensureNonEmptyString(variableId, "variableId");
|
||||
ensureNonEmptyString(user, "user");
|
||||
return this.http.request({
|
||||
method: "PUT",
|
||||
path: `/conversations/${conversationId}/variables/${variableId}`,
|
||||
data: {
|
||||
user,
|
||||
value,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
annotationReplyAction(
|
||||
action: "enable" | "disable",
|
||||
request: AnnotationReplyActionRequest
|
||||
): Promise<DifyResponse<AnnotationResponse>> {
|
||||
ensureNonEmptyString(action, "action");
|
||||
ensureNonEmptyString(request.embedding_provider_name, "embedding_provider_name");
|
||||
ensureNonEmptyString(request.embedding_model_name, "embedding_model_name");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/apps/annotation-reply/${action}`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
getAnnotationReplyStatus(
|
||||
action: "enable" | "disable",
|
||||
jobId: string
|
||||
): Promise<DifyResponse<AnnotationResponse>> {
|
||||
ensureNonEmptyString(action, "action");
|
||||
ensureNonEmptyString(jobId, "jobId");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/apps/annotation-reply/${action}/status/${jobId}`,
|
||||
});
|
||||
}
|
||||
|
||||
listAnnotations(
|
||||
options?: AnnotationListOptions
|
||||
): Promise<DifyResponse<AnnotationResponse>> {
|
||||
ensureOptionalInt(options?.page, "page");
|
||||
ensureOptionalInt(options?.limit, "limit");
|
||||
ensureOptionalString(options?.keyword, "keyword");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/apps/annotations",
|
||||
query: {
|
||||
page: options?.page,
|
||||
limit: options?.limit,
|
||||
keyword: options?.keyword ?? undefined,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
createAnnotation(
|
||||
request: AnnotationCreateRequest
|
||||
): Promise<DifyResponse<AnnotationResponse>> {
|
||||
ensureNonEmptyString(request.question, "question");
|
||||
ensureNonEmptyString(request.answer, "answer");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: "/apps/annotations",
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
updateAnnotation(
|
||||
annotationId: string,
|
||||
request: AnnotationCreateRequest
|
||||
): Promise<DifyResponse<AnnotationResponse>> {
|
||||
ensureNonEmptyString(annotationId, "annotationId");
|
||||
ensureNonEmptyString(request.question, "question");
|
||||
ensureNonEmptyString(request.answer, "answer");
|
||||
return this.http.request({
|
||||
method: "PUT",
|
||||
path: `/apps/annotations/${annotationId}`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
deleteAnnotation(
|
||||
annotationId: string
|
||||
): Promise<DifyResponse<AnnotationResponse>> {
|
||||
ensureNonEmptyString(annotationId, "annotationId");
|
||||
return this.http.request({
|
||||
method: "DELETE",
|
||||
path: `/apps/annotations/${annotationId}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Note: audioToText is inherited from DifyClient
|
||||
}
|
||||
|
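
For orientation, here is a minimal usage sketch of the two `createChatMessage` call styles defined above. It assumes `chat` is an already-constructed `ChatClient` wired to an HTTP client; the query, user, and conversation values are placeholders, not part of this change.

```typescript
// Sketch only: `chat` is assumed to be an existing ChatClient instance.
// Object form: response_mode on the request decides blocking vs streaming.
const blocking = await chat.createChatMessage({
  inputs: {},
  query: "Hello",
  user: "user-1",
  response_mode: "blocking",
});

// Positional form: the fourth argument toggles streaming, and an optional
// conversation id continues an existing conversation.
const streamed = await chat.createChatMessage({}, "Hello", "user-1", true, "conv-1");
```
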
|
@ -0,0 +1,83 @@
|
|||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { CompletionClient } from "./completion";
|
||||
import { createHttpClientWithSpies } from "../../tests/test-utils";
|
||||
|
||||
describe("CompletionClient", () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("creates completion messages in blocking mode", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const completion = new CompletionClient(client);
|
||||
|
||||
await completion.createCompletionMessage({ input: "x" }, "user", false);
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/completion-messages",
|
||||
data: {
|
||||
inputs: { input: "x" },
|
||||
user: "user",
|
||||
files: undefined,
|
||||
response_mode: "blocking",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("creates completion messages in streaming mode", async () => {
|
||||
const { client, requestStream } = createHttpClientWithSpies();
|
||||
const completion = new CompletionClient(client);
|
||||
|
||||
await completion.createCompletionMessage({
|
||||
inputs: { input: "x" },
|
||||
user: "user",
|
||||
response_mode: "streaming",
|
||||
});
|
||||
|
||||
expect(requestStream).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/completion-messages",
|
||||
data: {
|
||||
inputs: { input: "x" },
|
||||
user: "user",
|
||||
response_mode: "streaming",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("stops completion messages", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const completion = new CompletionClient(client);
|
||||
|
||||
await completion.stopCompletionMessage("task", "user");
|
||||
await completion.stop("task", "user");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/completion-messages/task/stop",
|
||||
data: { user: "user" },
|
||||
});
|
||||
});
|
||||
|
||||
it("supports deprecated runWorkflow", async () => {
|
||||
const { client, request, requestStream } = createHttpClientWithSpies();
|
||||
const completion = new CompletionClient(client);
|
||||
const warn = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
|
||||
await completion.runWorkflow({ input: "x" }, "user", false);
|
||||
await completion.runWorkflow({ input: "x" }, "user", true);
|
||||
|
||||
expect(warn).toHaveBeenCalled();
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/workflows/run",
|
||||
data: { inputs: { input: "x" }, user: "user", response_mode: "blocking" },
|
||||
});
|
||||
expect(requestStream).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/workflows/run",
|
||||
data: { inputs: { input: "x" }, user: "user", response_mode: "streaming" },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
import { DifyClient } from "./base";
|
||||
import type { CompletionRequest, CompletionResponse } from "../types/completion";
|
||||
import type { DifyResponse, DifyStream } from "../types/common";
|
||||
import { ensureNonEmptyString } from "./validation";
|
||||
|
||||
const warned = new Set<string>();
|
||||
const warnOnce = (message: string): void => {
|
||||
if (warned.has(message)) {
|
||||
return;
|
||||
}
|
||||
warned.add(message);
|
||||
console.warn(message);
|
||||
};
|
||||
|
||||
export class CompletionClient extends DifyClient {
|
||||
createCompletionMessage(
|
||||
request: CompletionRequest
|
||||
): Promise<DifyResponse<CompletionResponse> | DifyStream<CompletionResponse>>;
|
||||
createCompletionMessage(
|
||||
inputs: Record<string, unknown>,
|
||||
user: string,
|
||||
stream?: boolean,
|
||||
files?: Array<Record<string, unknown>> | null
|
||||
): Promise<DifyResponse<CompletionResponse> | DifyStream<CompletionResponse>>;
|
||||
createCompletionMessage(
|
||||
inputOrRequest: CompletionRequest | Record<string, unknown>,
|
||||
user?: string,
|
||||
stream = false,
|
||||
files?: Array<Record<string, unknown>> | null
|
||||
): Promise<DifyResponse<CompletionResponse> | DifyStream<CompletionResponse>> {
|
||||
let payload: CompletionRequest;
|
||||
let shouldStream = stream;
|
||||
|
||||
if (user === undefined && "user" in (inputOrRequest as CompletionRequest)) {
|
||||
payload = inputOrRequest as CompletionRequest;
|
||||
shouldStream = payload.response_mode === "streaming";
|
||||
} else {
|
||||
ensureNonEmptyString(user, "user");
|
||||
payload = {
|
||||
inputs: inputOrRequest as Record<string, unknown>,
|
||||
user,
|
||||
files,
|
||||
response_mode: stream ? "streaming" : "blocking",
|
||||
};
|
||||
}
|
||||
|
||||
ensureNonEmptyString(payload.user, "user");
|
||||
|
||||
if (shouldStream) {
|
||||
return this.http.requestStream<CompletionResponse>({
|
||||
method: "POST",
|
||||
path: "/completion-messages",
|
||||
data: payload,
|
||||
});
|
||||
}
|
||||
|
||||
return this.http.request<CompletionResponse>({
|
||||
method: "POST",
|
||||
path: "/completion-messages",
|
||||
data: payload,
|
||||
});
|
||||
}
|
||||
|
||||
stopCompletionMessage(
|
||||
taskId: string,
|
||||
user: string
|
||||
): Promise<DifyResponse<CompletionResponse>> {
|
||||
ensureNonEmptyString(taskId, "taskId");
|
||||
ensureNonEmptyString(user, "user");
|
||||
return this.http.request<CompletionResponse>({
|
||||
method: "POST",
|
||||
path: `/completion-messages/${taskId}/stop`,
|
||||
data: { user },
|
||||
});
|
||||
}
|
||||
|
||||
stop(
|
||||
taskId: string,
|
||||
user: string
|
||||
): Promise<DifyResponse<CompletionResponse>> {
|
||||
return this.stopCompletionMessage(taskId, user);
|
||||
}
|
||||
|
||||
runWorkflow(
|
||||
inputs: Record<string, unknown>,
|
||||
user: string,
|
||||
stream = false
|
||||
): Promise<DifyResponse<Record<string, unknown>> | DifyStream<Record<string, unknown>>> {
|
||||
warnOnce(
|
||||
"CompletionClient.runWorkflow is deprecated. Use WorkflowClient.run instead."
|
||||
);
|
||||
ensureNonEmptyString(user, "user");
|
||||
const payload = {
|
||||
inputs,
|
||||
user,
|
||||
response_mode: stream ? "streaming" : "blocking",
|
||||
};
|
||||
if (stream) {
|
||||
return this.http.requestStream<Record<string, unknown>>({
|
||||
method: "POST",
|
||||
path: "/workflows/run",
|
||||
data: payload,
|
||||
});
|
||||
}
|
||||
return this.http.request<Record<string, unknown>>({
|
||||
method: "POST",
|
||||
path: "/workflows/run",
|
||||
data: payload,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
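
A similarly hedged sketch of the `CompletionClient` surface above; `completion` is assumed to already exist, and the input values are placeholders.

```typescript
// Sketch only: `completion` is assumed to be an existing CompletionClient instance.
// Positional form, blocking by default; pass `true` as the third argument to stream.
const result = await completion.createCompletionMessage({ prompt: "Hi" }, "user-1");

// Request-object form: response_mode selects streaming explicitly.
const stream = await completion.createCompletionMessage({
  inputs: { prompt: "Hi" },
  user: "user-1",
  response_mode: "streaming",
});

// runWorkflow still works but logs a one-time deprecation warning;
// new code should prefer WorkflowClient.run.
await completion.runWorkflow({ prompt: "Hi" }, "user-1");
```
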
|
@ -0,0 +1,249 @@
|
|||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { KnowledgeBaseClient } from "./knowledge-base";
|
||||
import { createHttpClientWithSpies } from "../../tests/test-utils";
|
||||
|
||||
describe("KnowledgeBaseClient", () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("handles dataset and tag operations", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const kb = new KnowledgeBaseClient(client);
|
||||
|
||||
await kb.listDatasets({
|
||||
page: 1,
|
||||
limit: 2,
|
||||
keyword: "k",
|
||||
includeAll: true,
|
||||
tagIds: ["t1"],
|
||||
});
|
||||
await kb.createDataset({ name: "dataset" });
|
||||
await kb.getDataset("ds");
|
||||
await kb.updateDataset("ds", { name: "new" });
|
||||
await kb.deleteDataset("ds");
|
||||
await kb.updateDocumentStatus("ds", "enable", ["doc1"]);
|
||||
|
||||
await kb.listTags();
|
||||
await kb.createTag({ name: "tag" });
|
||||
await kb.updateTag({ tag_id: "tag", name: "name" });
|
||||
await kb.deleteTag({ tag_id: "tag" });
|
||||
await kb.bindTags({ tag_ids: ["tag"], target_id: "doc" });
|
||||
await kb.unbindTags({ tag_id: "tag", target_id: "doc" });
|
||||
await kb.getDatasetTags("ds");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/datasets",
|
||||
query: {
|
||||
page: 1,
|
||||
limit: 2,
|
||||
keyword: "k",
|
||||
include_all: true,
|
||||
tag_ids: ["t1"],
|
||||
},
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets",
|
||||
data: { name: "dataset" },
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "PATCH",
|
||||
path: "/datasets/ds",
|
||||
data: { name: "new" },
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "PATCH",
|
||||
path: "/datasets/ds/documents/status/enable",
|
||||
data: { document_ids: ["doc1"] },
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/tags/binding",
|
||||
data: { tag_ids: ["tag"], target_id: "doc" },
|
||||
});
|
||||
});
|
||||
|
||||
it("handles document operations", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const kb = new KnowledgeBaseClient(client);
|
||||
const form = { append: vi.fn(), getHeaders: () => ({}) };
|
||||
|
||||
await kb.createDocumentByText("ds", { name: "doc", text: "text" });
|
||||
await kb.updateDocumentByText("ds", "doc", { name: "doc2" });
|
||||
await kb.createDocumentByFile("ds", form);
|
||||
await kb.updateDocumentByFile("ds", "doc", form);
|
||||
await kb.listDocuments("ds", { page: 1, limit: 20, keyword: "k" });
|
||||
await kb.getDocument("ds", "doc", { metadata: "all" });
|
||||
await kb.deleteDocument("ds", "doc");
|
||||
await kb.getDocumentIndexingStatus("ds", "batch");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/ds/document/create_by_text",
|
||||
data: { name: "doc", text: "text" },
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/ds/documents/doc/update_by_text",
|
||||
data: { name: "doc2" },
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/ds/document/create_by_file",
|
||||
data: form,
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/datasets/ds/documents",
|
||||
query: { page: 1, limit: 20, keyword: "k", status: undefined },
|
||||
});
|
||||
});
|
||||
|
||||
it("handles segments and child chunks", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const kb = new KnowledgeBaseClient(client);
|
||||
|
||||
await kb.createSegments("ds", "doc", { segments: [{ content: "x" }] });
|
||||
await kb.listSegments("ds", "doc", { page: 1, limit: 10, keyword: "k" });
|
||||
await kb.getSegment("ds", "doc", "seg");
|
||||
await kb.updateSegment("ds", "doc", "seg", {
|
||||
segment: { content: "y" },
|
||||
});
|
||||
await kb.deleteSegment("ds", "doc", "seg");
|
||||
|
||||
await kb.createChildChunk("ds", "doc", "seg", { content: "c" });
|
||||
await kb.listChildChunks("ds", "doc", "seg", { page: 1, limit: 10 });
|
||||
await kb.updateChildChunk("ds", "doc", "seg", "child", {
|
||||
content: "c2",
|
||||
});
|
||||
await kb.deleteChildChunk("ds", "doc", "seg", "child");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/ds/documents/doc/segments",
|
||||
data: { segments: [{ content: "x" }] },
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/ds/documents/doc/segments/seg",
|
||||
data: { segment: { content: "y" } },
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "PATCH",
|
||||
path: "/datasets/ds/documents/doc/segments/seg/child_chunks/child",
|
||||
data: { content: "c2" },
|
||||
});
|
||||
});
|
||||
|
||||
it("handles metadata and retrieval", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const kb = new KnowledgeBaseClient(client);
|
||||
|
||||
await kb.listMetadata("ds");
|
||||
await kb.createMetadata("ds", { name: "m", type: "string" });
|
||||
await kb.updateMetadata("ds", "mid", { name: "m2" });
|
||||
await kb.deleteMetadata("ds", "mid");
|
||||
await kb.listBuiltInMetadata("ds");
|
||||
await kb.updateBuiltInMetadata("ds", "enable");
|
||||
await kb.updateDocumentsMetadata("ds", {
|
||||
operation_data: [
|
||||
{ document_id: "doc", metadata_list: [{ id: "m", name: "n" }] },
|
||||
],
|
||||
});
|
||||
await kb.hitTesting("ds", { query: "q" });
|
||||
await kb.retrieve("ds", { query: "q" });
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/datasets/ds/metadata",
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/ds/metadata",
|
||||
data: { name: "m", type: "string" },
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/ds/hit-testing",
|
||||
data: { query: "q" },
|
||||
});
|
||||
});
|
||||
|
||||
it("handles pipeline operations", async () => {
|
||||
const { client, request, requestStream } = createHttpClientWithSpies();
|
||||
const kb = new KnowledgeBaseClient(client);
|
||||
const warn = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
const form = { append: vi.fn(), getHeaders: () => ({}) };
|
||||
|
||||
await kb.listDatasourcePlugins("ds", { isPublished: true });
|
||||
await kb.runDatasourceNode("ds", "node", {
|
||||
inputs: { input: "x" },
|
||||
datasource_type: "custom",
|
||||
is_published: true,
|
||||
});
|
||||
await kb.runPipeline("ds", {
|
||||
inputs: { input: "x" },
|
||||
datasource_type: "custom",
|
||||
datasource_info_list: [],
|
||||
start_node_id: "start",
|
||||
is_published: true,
|
||||
response_mode: "streaming",
|
||||
});
|
||||
await kb.runPipeline("ds", {
|
||||
inputs: { input: "x" },
|
||||
datasource_type: "custom",
|
||||
datasource_info_list: [],
|
||||
start_node_id: "start",
|
||||
is_published: true,
|
||||
response_mode: "blocking",
|
||||
});
|
||||
await kb.uploadPipelineFile(form);
|
||||
|
||||
expect(warn).toHaveBeenCalled();
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/datasets/ds/pipeline/datasource-plugins",
|
||||
query: { is_published: true },
|
||||
});
|
||||
expect(requestStream).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/ds/pipeline/datasource/nodes/node/run",
|
||||
data: {
|
||||
inputs: { input: "x" },
|
||||
datasource_type: "custom",
|
||||
is_published: true,
|
||||
},
|
||||
});
|
||||
expect(requestStream).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/ds/pipeline/run",
|
||||
data: {
|
||||
inputs: { input: "x" },
|
||||
datasource_type: "custom",
|
||||
datasource_info_list: [],
|
||||
start_node_id: "start",
|
||||
is_published: true,
|
||||
response_mode: "streaming",
|
||||
},
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/ds/pipeline/run",
|
||||
data: {
|
||||
inputs: { input: "x" },
|
||||
datasource_type: "custom",
|
||||
datasource_info_list: [],
|
||||
start_node_id: "start",
|
||||
is_published: true,
|
||||
response_mode: "blocking",
|
||||
},
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/datasets/pipeline/file-upload",
|
||||
data: form,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,706 @@
|
|||
import { DifyClient } from "./base";
|
||||
import type {
|
||||
DatasetCreateRequest,
|
||||
DatasetListOptions,
|
||||
DatasetTagBindingRequest,
|
||||
DatasetTagCreateRequest,
|
||||
DatasetTagDeleteRequest,
|
||||
DatasetTagUnbindingRequest,
|
||||
DatasetTagUpdateRequest,
|
||||
DatasetUpdateRequest,
|
||||
DocumentGetOptions,
|
||||
DocumentListOptions,
|
||||
DocumentStatusAction,
|
||||
DocumentTextCreateRequest,
|
||||
DocumentTextUpdateRequest,
|
||||
SegmentCreateRequest,
|
||||
SegmentListOptions,
|
||||
SegmentUpdateRequest,
|
||||
ChildChunkCreateRequest,
|
||||
ChildChunkListOptions,
|
||||
ChildChunkUpdateRequest,
|
||||
MetadataCreateRequest,
|
||||
MetadataOperationRequest,
|
||||
MetadataUpdateRequest,
|
||||
HitTestingRequest,
|
||||
DatasourcePluginListOptions,
|
||||
DatasourceNodeRunRequest,
|
||||
PipelineRunRequest,
|
||||
KnowledgeBaseResponse,
|
||||
PipelineStreamEvent,
|
||||
} from "../types/knowledge-base";
|
||||
import type { DifyResponse, DifyStream, QueryParams } from "../types/common";
|
||||
import {
|
||||
ensureNonEmptyString,
|
||||
ensureOptionalBoolean,
|
||||
ensureOptionalInt,
|
||||
ensureOptionalString,
|
||||
ensureStringArray,
|
||||
} from "./validation";
|
||||
import { FileUploadError, ValidationError } from "../errors/dify-error";
|
||||
import { isFormData } from "../http/form-data";
|
||||
|
||||
const warned = new Set<string>();
|
||||
const warnOnce = (message: string): void => {
|
||||
if (warned.has(message)) {
|
||||
return;
|
||||
}
|
||||
warned.add(message);
|
||||
console.warn(message);
|
||||
};
|
||||
|
||||
const ensureFormData = (form: unknown, context: string): void => {
|
||||
if (!isFormData(form)) {
|
||||
throw new FileUploadError(`${context} requires FormData`);
|
||||
}
|
||||
};
|
||||
|
||||
const ensureNonEmptyArray = (value: unknown, name: string): void => {
|
||||
if (!Array.isArray(value) || value.length === 0) {
|
||||
throw new ValidationError(`${name} must be a non-empty array`);
|
||||
}
|
||||
};
|
||||
|
||||
const warnPipelineRoutes = (): void => {
|
||||
warnOnce(
|
||||
"RAG pipeline endpoints may be unavailable unless the service API registers dataset/rag_pipeline routes."
|
||||
);
|
||||
};
|
||||
|
||||
export class KnowledgeBaseClient extends DifyClient {
|
||||
async listDatasets(
|
||||
options?: DatasetListOptions
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureOptionalInt(options?.page, "page");
|
||||
ensureOptionalInt(options?.limit, "limit");
|
||||
ensureOptionalString(options?.keyword, "keyword");
|
||||
ensureOptionalBoolean(options?.includeAll, "includeAll");
|
||||
|
||||
const query: QueryParams = {
|
||||
page: options?.page,
|
||||
limit: options?.limit,
|
||||
keyword: options?.keyword ?? undefined,
|
||||
include_all: options?.includeAll ?? undefined,
|
||||
};
|
||||
|
||||
if (options?.tagIds && options.tagIds.length > 0) {
|
||||
ensureStringArray(options.tagIds, "tagIds");
|
||||
query.tag_ids = options.tagIds;
|
||||
}
|
||||
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/datasets",
|
||||
query,
|
||||
});
|
||||
}
|
||||
|
||||
async createDataset(
|
||||
request: DatasetCreateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(request.name, "name");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: "/datasets",
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async getDataset(datasetId: string): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/datasets/${datasetId}`,
|
||||
});
|
||||
}
|
||||
|
||||
async updateDataset(
|
||||
datasetId: string,
|
||||
request: DatasetUpdateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
if (request.name !== undefined && request.name !== null) {
|
||||
ensureNonEmptyString(request.name, "name");
|
||||
}
|
||||
return this.http.request({
|
||||
method: "PATCH",
|
||||
path: `/datasets/${datasetId}`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async deleteDataset(datasetId: string): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
return this.http.request({
|
||||
method: "DELETE",
|
||||
path: `/datasets/${datasetId}`,
|
||||
});
|
||||
}
|
||||
|
||||
async updateDocumentStatus(
|
||||
datasetId: string,
|
||||
action: DocumentStatusAction,
|
||||
documentIds: string[]
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(action, "action");
|
||||
ensureStringArray(documentIds, "documentIds");
|
||||
return this.http.request({
|
||||
method: "PATCH",
|
||||
path: `/datasets/${datasetId}/documents/status/${action}`,
|
||||
data: {
|
||||
document_ids: documentIds,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async listTags(): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/datasets/tags",
|
||||
});
|
||||
}
|
||||
|
||||
async createTag(
|
||||
request: DatasetTagCreateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(request.name, "name");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: "/datasets/tags",
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async updateTag(
|
||||
request: DatasetTagUpdateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(request.tag_id, "tag_id");
|
||||
ensureNonEmptyString(request.name, "name");
|
||||
return this.http.request({
|
||||
method: "PATCH",
|
||||
path: "/datasets/tags",
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async deleteTag(
|
||||
request: DatasetTagDeleteRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(request.tag_id, "tag_id");
|
||||
return this.http.request({
|
||||
method: "DELETE",
|
||||
path: "/datasets/tags",
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async bindTags(
|
||||
request: DatasetTagBindingRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureStringArray(request.tag_ids, "tag_ids");
|
||||
ensureNonEmptyString(request.target_id, "target_id");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: "/datasets/tags/binding",
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async unbindTags(
|
||||
request: DatasetTagUnbindingRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(request.tag_id, "tag_id");
|
||||
ensureNonEmptyString(request.target_id, "target_id");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: "/datasets/tags/unbinding",
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async getDatasetTags(
|
||||
datasetId: string
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/datasets/${datasetId}/tags`,
|
||||
});
|
||||
}
|
||||
|
||||
async createDocumentByText(
|
||||
datasetId: string,
|
||||
request: DocumentTextCreateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(request.name, "name");
|
||||
ensureNonEmptyString(request.text, "text");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/document/create_by_text`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async updateDocumentByText(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
request: DocumentTextUpdateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
if (request.name !== undefined && request.name !== null) {
|
||||
ensureNonEmptyString(request.name, "name");
|
||||
}
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}/update_by_text`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async createDocumentByFile(
|
||||
datasetId: string,
|
||||
form: unknown
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureFormData(form, "createDocumentByFile");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/document/create_by_file`,
|
||||
data: form,
|
||||
});
|
||||
}
|
||||
|
||||
async updateDocumentByFile(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
form: unknown
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
ensureFormData(form, "updateDocumentByFile");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}/update_by_file`,
|
||||
data: form,
|
||||
});
|
||||
}
|
||||
|
||||
async listDocuments(
|
||||
datasetId: string,
|
||||
options?: DocumentListOptions
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureOptionalInt(options?.page, "page");
|
||||
ensureOptionalInt(options?.limit, "limit");
|
||||
ensureOptionalString(options?.keyword, "keyword");
|
||||
ensureOptionalString(options?.status, "status");
|
||||
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/datasets/${datasetId}/documents`,
|
||||
query: {
|
||||
page: options?.page,
|
||||
limit: options?.limit,
|
||||
keyword: options?.keyword ?? undefined,
|
||||
status: options?.status ?? undefined,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async getDocument(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
options?: DocumentGetOptions
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
if (options?.metadata) {
|
||||
const allowed = new Set(["all", "only", "without"]);
|
||||
if (!allowed.has(options.metadata)) {
|
||||
throw new ValidationError("metadata must be one of all, only, without");
|
||||
}
|
||||
}
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}`,
|
||||
query: {
|
||||
metadata: options?.metadata ?? undefined,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async deleteDocument(
|
||||
datasetId: string,
|
||||
documentId: string
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
return this.http.request({
|
||||
method: "DELETE",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}`,
|
||||
});
|
||||
}
|
||||
|
||||
async getDocumentIndexingStatus(
|
||||
datasetId: string,
|
||||
batch: string
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(batch, "batch");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/datasets/${datasetId}/documents/${batch}/indexing-status`,
|
||||
});
|
||||
}
|
||||
|
||||
async createSegments(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
request: SegmentCreateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
ensureNonEmptyArray(request.segments, "segments");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}/segments`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async listSegments(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
options?: SegmentListOptions
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
ensureOptionalInt(options?.page, "page");
|
||||
ensureOptionalInt(options?.limit, "limit");
|
||||
ensureOptionalString(options?.keyword, "keyword");
|
||||
if (options?.status && options.status.length > 0) {
|
||||
ensureStringArray(options.status, "status");
|
||||
}
|
||||
|
||||
const query: QueryParams = {
|
||||
page: options?.page,
|
||||
limit: options?.limit,
|
||||
keyword: options?.keyword ?? undefined,
|
||||
};
|
||||
if (options?.status && options.status.length > 0) {
|
||||
query.status = options.status;
|
||||
}
|
||||
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}/segments`,
|
||||
query,
|
||||
});
|
||||
}
|
||||
|
||||
async getSegment(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
segmentId: string
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
ensureNonEmptyString(segmentId, "segmentId");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}`,
|
||||
});
|
||||
}
|
||||
|
||||
async updateSegment(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
segmentId: string,
|
||||
request: SegmentUpdateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
ensureNonEmptyString(segmentId, "segmentId");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async deleteSegment(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
segmentId: string
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
ensureNonEmptyString(segmentId, "segmentId");
|
||||
return this.http.request({
|
||||
method: "DELETE",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}`,
|
||||
});
|
||||
}
|
||||
|
||||
async createChildChunk(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
segmentId: string,
|
||||
request: ChildChunkCreateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
ensureNonEmptyString(segmentId, "segmentId");
|
||||
ensureNonEmptyString(request.content, "content");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}/child_chunks`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async listChildChunks(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
segmentId: string,
|
||||
options?: ChildChunkListOptions
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
ensureNonEmptyString(segmentId, "segmentId");
|
||||
ensureOptionalInt(options?.page, "page");
|
||||
ensureOptionalInt(options?.limit, "limit");
|
||||
ensureOptionalString(options?.keyword, "keyword");
|
||||
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}/child_chunks`,
|
||||
query: {
|
||||
page: options?.page,
|
||||
limit: options?.limit,
|
||||
keyword: options?.keyword ?? undefined,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async updateChildChunk(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
segmentId: string,
|
||||
childChunkId: string,
|
||||
request: ChildChunkUpdateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
ensureNonEmptyString(segmentId, "segmentId");
|
||||
ensureNonEmptyString(childChunkId, "childChunkId");
|
||||
ensureNonEmptyString(request.content, "content");
|
||||
return this.http.request({
|
||||
method: "PATCH",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}/child_chunks/${childChunkId}`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async deleteChildChunk(
|
||||
datasetId: string,
|
||||
documentId: string,
|
||||
segmentId: string,
|
||||
childChunkId: string
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(documentId, "documentId");
|
||||
ensureNonEmptyString(segmentId, "segmentId");
|
||||
ensureNonEmptyString(childChunkId, "childChunkId");
|
||||
return this.http.request({
|
||||
method: "DELETE",
|
||||
path: `/datasets/${datasetId}/documents/${documentId}/segments/${segmentId}/child_chunks/${childChunkId}`,
|
||||
});
|
||||
}
|
||||
|
||||
async listMetadata(
|
||||
datasetId: string
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/datasets/${datasetId}/metadata`,
|
||||
});
|
||||
}
|
||||
|
||||
async createMetadata(
|
||||
datasetId: string,
|
||||
request: MetadataCreateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(request.name, "name");
|
||||
ensureNonEmptyString(request.type, "type");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/metadata`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async updateMetadata(
|
||||
datasetId: string,
|
||||
metadataId: string,
|
||||
request: MetadataUpdateRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(metadataId, "metadataId");
|
||||
ensureNonEmptyString(request.name, "name");
|
||||
return this.http.request({
|
||||
method: "PATCH",
|
||||
path: `/datasets/${datasetId}/metadata/${metadataId}`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async deleteMetadata(
|
||||
datasetId: string,
|
||||
metadataId: string
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(metadataId, "metadataId");
|
||||
return this.http.request({
|
||||
method: "DELETE",
|
||||
path: `/datasets/${datasetId}/metadata/${metadataId}`,
|
||||
});
|
||||
}
|
||||
|
||||
async listBuiltInMetadata(
|
||||
datasetId: string
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/datasets/${datasetId}/metadata/built-in`,
|
||||
});
|
||||
}
|
||||
|
||||
async updateBuiltInMetadata(
|
||||
datasetId: string,
|
||||
action: "enable" | "disable"
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(action, "action");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/metadata/built-in/${action}`,
|
||||
});
|
||||
}
|
||||
|
||||
async updateDocumentsMetadata(
|
||||
datasetId: string,
|
||||
request: MetadataOperationRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyArray(request.operation_data, "operation_data");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/documents/metadata`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async hitTesting(
|
||||
datasetId: string,
|
||||
request: HitTestingRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
if (request.query !== undefined && request.query !== null) {
|
||||
ensureOptionalString(request.query, "query");
|
||||
}
|
||||
if (request.attachment_ids && request.attachment_ids.length > 0) {
|
||||
ensureStringArray(request.attachment_ids, "attachment_ids");
|
||||
}
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/hit-testing`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async retrieve(
|
||||
datasetId: string,
|
||||
request: HitTestingRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/retrieve`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async listDatasourcePlugins(
|
||||
datasetId: string,
|
||||
options?: DatasourcePluginListOptions
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
warnPipelineRoutes();
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureOptionalBoolean(options?.isPublished, "isPublished");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/datasets/${datasetId}/pipeline/datasource-plugins`,
|
||||
query: {
|
||||
is_published: options?.isPublished ?? undefined,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async runDatasourceNode(
|
||||
datasetId: string,
|
||||
nodeId: string,
|
||||
request: DatasourceNodeRunRequest
|
||||
): Promise<DifyStream<PipelineStreamEvent>> {
|
||||
warnPipelineRoutes();
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(nodeId, "nodeId");
|
||||
ensureNonEmptyString(request.datasource_type, "datasource_type");
|
||||
return this.http.requestStream<PipelineStreamEvent>({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/pipeline/datasource/nodes/${nodeId}/run`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async runPipeline(
|
||||
datasetId: string,
|
||||
request: PipelineRunRequest
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse> | DifyStream<PipelineStreamEvent>> {
|
||||
warnPipelineRoutes();
|
||||
ensureNonEmptyString(datasetId, "datasetId");
|
||||
ensureNonEmptyString(request.datasource_type, "datasource_type");
|
||||
ensureNonEmptyString(request.start_node_id, "start_node_id");
|
||||
const shouldStream = request.response_mode === "streaming";
|
||||
if (shouldStream) {
|
||||
return this.http.requestStream<PipelineStreamEvent>({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/pipeline/run`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
return this.http.request<KnowledgeBaseResponse>({
|
||||
method: "POST",
|
||||
path: `/datasets/${datasetId}/pipeline/run`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
async uploadPipelineFile(
|
||||
form: unknown
|
||||
): Promise<DifyResponse<KnowledgeBaseResponse>> {
|
||||
warnPipelineRoutes();
|
||||
ensureFormData(form, "uploadPipelineFile");
|
||||
return this.http.request({
|
||||
method: "POST",
|
||||
path: "/datasets/pipeline/file-upload",
|
||||
data: form,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
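
To show how the dataset and document methods above compose, here is a minimal, hedged end-to-end sketch. `kb` is assumed to be an existing `KnowledgeBaseClient`; `"ds"` and `"batch"` stand in for identifiers a real API response would provide.

```typescript
// Sketch only: `kb` is assumed to be an existing KnowledgeBaseClient instance,
// and "ds" stands in for a real dataset id returned by createDataset.
await kb.createDataset({ name: "support-articles" });

await kb.createDocumentByText("ds", {
  name: "faq",
  text: "Q: How do I reset my password?\nA: ...",
});

// "batch" stands in for the batch value reported when the document was created.
await kb.getDocumentIndexingStatus("ds", "batch");

// Query the populated dataset.
await kb.hitTesting("ds", { query: "How do I reset my password?" });
```
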
|
@ -0,0 +1,91 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
ensureNonEmptyString,
|
||||
ensureOptionalBoolean,
|
||||
ensureOptionalInt,
|
||||
ensureOptionalString,
|
||||
ensureOptionalStringArray,
|
||||
ensureRating,
|
||||
ensureStringArray,
|
||||
validateParams,
|
||||
} from "./validation";
|
||||
|
||||
const makeLongString = (length: number): string => "a".repeat(length);
|
||||
|
||||
describe("validation utilities", () => {
|
||||
it("ensureNonEmptyString throws on empty or whitespace", () => {
|
||||
expect(() => ensureNonEmptyString("", "name")).toThrow();
|
||||
expect(() => ensureNonEmptyString(" ", "name")).toThrow();
|
||||
});
|
||||
|
||||
it("ensureNonEmptyString throws on overly long strings", () => {
|
||||
expect(() =>
|
||||
ensureNonEmptyString(makeLongString(10001), "name")
|
||||
).toThrow();
|
||||
});
|
||||
|
||||
it("ensureOptionalString ignores undefined and validates when set", () => {
|
||||
expect(() => ensureOptionalString(undefined, "opt")).not.toThrow();
|
||||
expect(() => ensureOptionalString("", "opt")).toThrow();
|
||||
});
|
||||
|
||||
it("ensureOptionalString throws on overly long strings", () => {
|
||||
expect(() => ensureOptionalString(makeLongString(10001), "opt")).toThrow();
|
||||
});
|
||||
|
||||
it("ensureOptionalInt validates integer", () => {
|
||||
expect(() => ensureOptionalInt(undefined, "limit")).not.toThrow();
|
||||
expect(() => ensureOptionalInt(1.2, "limit")).toThrow();
|
||||
});
|
||||
|
||||
it("ensureOptionalBoolean validates boolean", () => {
|
||||
expect(() => ensureOptionalBoolean(undefined, "flag")).not.toThrow();
|
||||
expect(() => ensureOptionalBoolean("yes", "flag")).toThrow();
|
||||
});
|
||||
|
||||
it("ensureStringArray enforces size and content", () => {
|
||||
expect(() => ensureStringArray([], "items")).toThrow();
|
||||
expect(() => ensureStringArray([""], "items")).toThrow();
|
||||
expect(() =>
|
||||
ensureStringArray(Array.from({ length: 1001 }, () => "a"), "items")
|
||||
).toThrow();
|
||||
expect(() => ensureStringArray(["ok"], "items")).not.toThrow();
|
||||
});
|
||||
|
||||
it("ensureOptionalStringArray ignores undefined", () => {
|
||||
expect(() => ensureOptionalStringArray(undefined, "tags")).not.toThrow();
|
||||
});
|
||||
|
||||
it("ensureOptionalStringArray validates when set", () => {
|
||||
expect(() => ensureOptionalStringArray(["valid"], "tags")).not.toThrow();
|
||||
expect(() => ensureOptionalStringArray([], "tags")).toThrow();
|
||||
expect(() => ensureOptionalStringArray([""], "tags")).toThrow();
|
||||
});
|
||||
|
||||
it("ensureRating validates allowed values", () => {
|
||||
expect(() => ensureRating(undefined)).not.toThrow();
|
||||
expect(() => ensureRating("like")).not.toThrow();
|
||||
expect(() => ensureRating("bad")).toThrow();
|
||||
});
|
||||
|
||||
it("validateParams enforces generic rules", () => {
|
||||
expect(() => validateParams({ user: 123 })).toThrow();
|
||||
expect(() => validateParams({ rating: "bad" })).toThrow();
|
||||
expect(() => validateParams({ page: 1.1 })).toThrow();
|
||||
expect(() => validateParams({ files: "bad" })).toThrow();
|
||||
// Empty strings are allowed for optional params (e.g., keyword: "" means no filter)
|
||||
expect(() => validateParams({ keyword: "" })).not.toThrow();
|
||||
expect(() => validateParams({ name: makeLongString(10001) })).toThrow();
|
||||
expect(() =>
|
||||
validateParams({ items: Array.from({ length: 1001 }, () => "a") })
|
||||
).toThrow();
|
||||
expect(() =>
|
||||
validateParams({
|
||||
data: Object.fromEntries(
|
||||
Array.from({ length: 101 }, (_, i) => [String(i), i])
|
||||
),
|
||||
})
|
||||
).toThrow();
|
||||
expect(() => validateParams({ user: "u", page: 1 })).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
import { ValidationError } from "../errors/dify-error";

const MAX_STRING_LENGTH = 10000;
const MAX_LIST_LENGTH = 1000;
const MAX_DICT_LENGTH = 100;

export function ensureNonEmptyString(
  value: unknown,
  name: string
): asserts value is string {
  if (typeof value !== "string" || value.trim().length === 0) {
    throw new ValidationError(`${name} must be a non-empty string`);
  }
  if (value.length > MAX_STRING_LENGTH) {
    throw new ValidationError(
      `${name} exceeds maximum length of ${MAX_STRING_LENGTH} characters`
    );
  }
}

/**
 * Validates optional string fields that must be non-empty when provided.
 * Use this for fields like `name` that are optional but should not be empty strings.
 *
 * For filter parameters that accept empty strings (e.g., `keyword: ""`),
 * use `validateParams` which allows empty strings for optional params.
 */
export function ensureOptionalString(value: unknown, name: string): void {
  if (value === undefined || value === null) {
    return;
  }
  if (typeof value !== "string" || value.trim().length === 0) {
    throw new ValidationError(`${name} must be a non-empty string when set`);
  }
  if (value.length > MAX_STRING_LENGTH) {
    throw new ValidationError(
      `${name} exceeds maximum length of ${MAX_STRING_LENGTH} characters`
    );
  }
}

export function ensureOptionalInt(value: unknown, name: string): void {
  if (value === undefined || value === null) {
    return;
  }
  if (!Number.isInteger(value)) {
    throw new ValidationError(`${name} must be an integer when set`);
  }
}

export function ensureOptionalBoolean(value: unknown, name: string): void {
  if (value === undefined || value === null) {
    return;
  }
  if (typeof value !== "boolean") {
    throw new ValidationError(`${name} must be a boolean when set`);
  }
}

export function ensureStringArray(value: unknown, name: string): void {
  if (!Array.isArray(value) || value.length === 0) {
    throw new ValidationError(`${name} must be a non-empty string array`);
  }
  if (value.length > MAX_LIST_LENGTH) {
    throw new ValidationError(
      `${name} exceeds maximum size of ${MAX_LIST_LENGTH} items`
    );
  }
  value.forEach((item) => {
    if (typeof item !== "string" || item.trim().length === 0) {
      throw new ValidationError(`${name} must contain non-empty strings`);
    }
  });
}

export function ensureOptionalStringArray(value: unknown, name: string): void {
  if (value === undefined || value === null) {
    return;
  }
  ensureStringArray(value, name);
}

export function ensureRating(value: unknown): void {
  if (value === undefined || value === null) {
    return;
  }
  if (value !== "like" && value !== "dislike") {
    throw new ValidationError("rating must be either 'like' or 'dislike'");
  }
}

export function validateParams(params: Record<string, unknown>): void {
  Object.entries(params).forEach(([key, value]) => {
    if (value === undefined || value === null) {
      return;
    }

    // Only check max length for strings; empty strings are allowed for optional params
    // Required fields are validated at method level via ensureNonEmptyString
    if (typeof value === "string") {
      if (value.length > MAX_STRING_LENGTH) {
        throw new ValidationError(
          `Parameter '${key}' exceeds maximum length of ${MAX_STRING_LENGTH} characters`
        );
      }
    } else if (Array.isArray(value)) {
      if (value.length > MAX_LIST_LENGTH) {
        throw new ValidationError(
          `Parameter '${key}' exceeds maximum size of ${MAX_LIST_LENGTH} items`
        );
      }
    } else if (typeof value === "object") {
      if (Object.keys(value as Record<string, unknown>).length > MAX_DICT_LENGTH) {
        throw new ValidationError(
          `Parameter '${key}' exceeds maximum size of ${MAX_DICT_LENGTH} items`
        );
      }
    }

    if (key === "user" && typeof value !== "string") {
      throw new ValidationError(`Parameter '${key}' must be a string`);
    }
    if (
      (key === "page" || key === "limit" || key === "page_size") &&
      !Number.isInteger(value)
    ) {
      throw new ValidationError(`Parameter '${key}' must be an integer`);
    }
    if (key === "files" && !Array.isArray(value) && typeof value !== "object") {
      throw new ValidationError(`Parameter '${key}' must be a list or dict`);
    }
    if (key === "rating" && value !== "like" && value !== "dislike") {
      throw new ValidationError(`Parameter '${key}' must be 'like' or 'dislike'`);
    }
  });
}
|
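
A short sketch of how the validators above behave, assuming the module is imported from a sibling file; the commented-out calls indicate inputs that would throw `ValidationError`.

```typescript
import { ensureNonEmptyString, ensureOptionalString, validateParams } from "./validation";

ensureNonEmptyString("ds-1", "datasetId");      // passes
// ensureNonEmptyString("   ", "datasetId");    // throws: must be non-empty

ensureOptionalString(undefined, "keyword");     // passes: unset values are skipped
// ensureOptionalString("", "keyword");         // throws: empty when set

validateParams({ keyword: "", page: 1 });       // passes: empty optional strings allowed here
// validateParams({ page: 1.5 });               // throws: page must be an integer
```
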
@ -0,0 +1,119 @@
|
|||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { WorkflowClient } from "./workflow";
|
||||
import { createHttpClientWithSpies } from "../../tests/test-utils";
|
||||
|
||||
describe("WorkflowClient", () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("runs workflows with blocking and streaming modes", async () => {
|
||||
const { client, request, requestStream } = createHttpClientWithSpies();
|
||||
const workflow = new WorkflowClient(client);
|
||||
|
||||
await workflow.run({ inputs: { input: "x" }, user: "user" });
|
||||
await workflow.run({ input: "x" }, "user", true);
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/workflows/run",
|
||||
data: {
|
||||
inputs: { input: "x" },
|
||||
user: "user",
|
||||
},
|
||||
});
|
||||
expect(requestStream).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/workflows/run",
|
||||
data: {
|
||||
inputs: { input: "x" },
|
||||
user: "user",
|
||||
response_mode: "streaming",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("runs workflow by id", async () => {
|
||||
const { client, request, requestStream } = createHttpClientWithSpies();
|
||||
const workflow = new WorkflowClient(client);
|
||||
|
||||
await workflow.runById("wf", {
|
||||
inputs: { input: "x" },
|
||||
user: "user",
|
||||
response_mode: "blocking",
|
||||
});
|
||||
await workflow.runById("wf", {
|
||||
inputs: { input: "x" },
|
||||
user: "user",
|
||||
response_mode: "streaming",
|
||||
});
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/workflows/wf/run",
|
||||
data: {
|
||||
inputs: { input: "x" },
|
||||
user: "user",
|
||||
response_mode: "blocking",
|
||||
},
|
||||
});
|
||||
expect(requestStream).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/workflows/wf/run",
|
||||
data: {
|
||||
inputs: { input: "x" },
|
||||
user: "user",
|
||||
response_mode: "streaming",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("gets run details and stops workflow", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const workflow = new WorkflowClient(client);
|
||||
|
||||
await workflow.getRun("run");
|
||||
await workflow.stop("task", "user");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/workflows/run/run",
|
||||
});
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "POST",
|
||||
path: "/workflows/tasks/task/stop",
|
||||
data: { user: "user" },
|
||||
});
|
||||
});
|
||||
|
||||
it("fetches workflow logs", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const workflow = new WorkflowClient(client);
|
||||
|
||||
// Use createdByEndUserSessionId to filter by user session (backend API parameter)
|
||||
await workflow.getLogs({
|
||||
keyword: "k",
|
||||
status: "succeeded",
|
||||
startTime: "2024-01-01",
|
||||
endTime: "2024-01-02",
|
||||
createdByEndUserSessionId: "session-123",
|
||||
page: 1,
|
||||
limit: 20,
|
||||
});
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/workflows/logs",
|
||||
query: {
|
||||
keyword: "k",
|
||||
status: "succeeded",
|
||||
created_at__before: "2024-01-02",
|
||||
created_at__after: "2024-01-01",
|
||||
created_by_end_user_session_id: "session-123",
|
||||
created_by_account: undefined,
|
||||
page: 1,
|
||||
limit: 20,
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@@ -0,0 +1,165 @@
|
|||
import { DifyClient } from "./base";
|
||||
import type { WorkflowRunRequest, WorkflowRunResponse } from "../types/workflow";
|
||||
import type { DifyResponse, DifyStream, QueryParams } from "../types/common";
|
||||
import {
|
||||
ensureNonEmptyString,
|
||||
ensureOptionalInt,
|
||||
ensureOptionalString,
|
||||
} from "./validation";
|
||||
|
||||
export class WorkflowClient extends DifyClient {
|
||||
run(
|
||||
request: WorkflowRunRequest
|
||||
): Promise<DifyResponse<WorkflowRunResponse> | DifyStream<WorkflowRunResponse>>;
|
||||
run(
|
||||
inputs: Record<string, unknown>,
|
||||
user: string,
|
||||
stream?: boolean
|
||||
): Promise<DifyResponse<WorkflowRunResponse> | DifyStream<WorkflowRunResponse>>;
|
||||
run(
|
||||
inputOrRequest: WorkflowRunRequest | Record<string, unknown>,
|
||||
user?: string,
|
||||
stream = false
|
||||
): Promise<DifyResponse<WorkflowRunResponse> | DifyStream<WorkflowRunResponse>> {
|
||||
let payload: WorkflowRunRequest;
|
||||
let shouldStream = stream;
|
||||
|
||||
if (user === undefined && "user" in (inputOrRequest as WorkflowRunRequest)) {
|
||||
payload = inputOrRequest as WorkflowRunRequest;
|
||||
shouldStream = payload.response_mode === "streaming";
|
||||
} else {
|
||||
ensureNonEmptyString(user, "user");
|
||||
payload = {
|
||||
inputs: inputOrRequest as Record<string, unknown>,
|
||||
user,
|
||||
response_mode: stream ? "streaming" : "blocking",
|
||||
};
|
||||
}
|
||||
|
||||
ensureNonEmptyString(payload.user, "user");
|
||||
|
||||
if (shouldStream) {
|
||||
return this.http.requestStream<WorkflowRunResponse>({
|
||||
method: "POST",
|
||||
path: "/workflows/run",
|
||||
data: payload,
|
||||
});
|
||||
}
|
||||
|
||||
return this.http.request<WorkflowRunResponse>({
|
||||
method: "POST",
|
||||
path: "/workflows/run",
|
||||
data: payload,
|
||||
});
|
||||
}
|
||||
|
||||
runById(
|
||||
workflowId: string,
|
||||
request: WorkflowRunRequest
|
||||
): Promise<DifyResponse<WorkflowRunResponse> | DifyStream<WorkflowRunResponse>> {
|
||||
ensureNonEmptyString(workflowId, "workflowId");
|
||||
ensureNonEmptyString(request.user, "user");
|
||||
if (request.response_mode === "streaming") {
|
||||
return this.http.requestStream<WorkflowRunResponse>({
|
||||
method: "POST",
|
||||
path: `/workflows/${workflowId}/run`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
return this.http.request<WorkflowRunResponse>({
|
||||
method: "POST",
|
||||
path: `/workflows/${workflowId}/run`,
|
||||
data: request,
|
||||
});
|
||||
}
|
||||
|
||||
getRun(workflowRunId: string): Promise<DifyResponse<WorkflowRunResponse>> {
|
||||
ensureNonEmptyString(workflowRunId, "workflowRunId");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/workflows/run/${workflowRunId}`,
|
||||
});
|
||||
}
|
||||
|
||||
stop(
|
||||
taskId: string,
|
||||
user: string
|
||||
): Promise<DifyResponse<WorkflowRunResponse>> {
|
||||
ensureNonEmptyString(taskId, "taskId");
|
||||
ensureNonEmptyString(user, "user");
|
||||
return this.http.request<WorkflowRunResponse>({
|
||||
method: "POST",
|
||||
path: `/workflows/tasks/${taskId}/stop`,
|
||||
data: { user },
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Get workflow execution logs with filtering options.
|
||||
*
|
||||
* Note: The backend API filters by `createdByEndUserSessionId` (end user session ID)
|
||||
* or `createdByAccount` (account ID), not by a generic `user` parameter.
|
||||
*/
|
||||
getLogs(options?: {
|
||||
keyword?: string;
|
||||
status?: string;
|
||||
createdAtBefore?: string;
|
||||
createdAtAfter?: string;
|
||||
createdByEndUserSessionId?: string;
|
||||
createdByAccount?: string;
|
||||
page?: number;
|
||||
limit?: number;
|
||||
startTime?: string;
|
||||
endTime?: string;
|
||||
}): Promise<DifyResponse<Record<string, unknown>>> {
|
||||
if (options?.keyword) {
|
||||
ensureOptionalString(options.keyword, "keyword");
|
||||
}
|
||||
if (options?.status) {
|
||||
ensureOptionalString(options.status, "status");
|
||||
}
|
||||
if (options?.createdAtBefore) {
|
||||
ensureOptionalString(options.createdAtBefore, "createdAtBefore");
|
||||
}
|
||||
if (options?.createdAtAfter) {
|
||||
ensureOptionalString(options.createdAtAfter, "createdAtAfter");
|
||||
}
|
||||
if (options?.createdByEndUserSessionId) {
|
||||
ensureOptionalString(
|
||||
options.createdByEndUserSessionId,
|
||||
"createdByEndUserSessionId"
|
||||
);
|
||||
}
|
||||
if (options?.createdByAccount) {
|
||||
ensureOptionalString(options.createdByAccount, "createdByAccount");
|
||||
}
|
||||
if (options?.startTime) {
|
||||
ensureOptionalString(options.startTime, "startTime");
|
||||
}
|
||||
if (options?.endTime) {
|
||||
ensureOptionalString(options.endTime, "endTime");
|
||||
}
|
||||
ensureOptionalInt(options?.page, "page");
|
||||
ensureOptionalInt(options?.limit, "limit");
|
||||
|
||||
const createdAtAfter = options?.createdAtAfter ?? options?.startTime;
|
||||
const createdAtBefore = options?.createdAtBefore ?? options?.endTime;
|
||||
|
||||
const query: QueryParams = {
|
||||
keyword: options?.keyword,
|
||||
status: options?.status,
|
||||
created_at__before: createdAtBefore,
|
||||
created_at__after: createdAtAfter,
|
||||
created_by_end_user_session_id: options?.createdByEndUserSessionId,
|
||||
created_by_account: options?.createdByAccount,
|
||||
page: options?.page,
|
||||
limit: options?.limit,
|
||||
};
|
||||
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: "/workflows/logs",
|
||||
query,
|
||||
});
|
||||
}
|
||||
}
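
As a usage sketch, both calling styles of `run` and the snake_case log filters look like this; the construction from an `HttpClient` mirrors the unit tests, and the API key, inputs, and session id are placeholders:

```typescript
import { HttpClient } from "../http/client";
import { WorkflowClient } from "./workflow";

async function demo(): Promise<void> {
  const http = new HttpClient({ apiKey: "app-xxxxxxxx" }); // placeholder key
  const workflow = new WorkflowClient(http);

  // Request-object form: response_mode decides blocking vs. streaming.
  const blocking = await workflow.run({
    inputs: { topic: "release notes" },
    user: "end-user-1",
    response_mode: "blocking",
  });

  // Positional form: the third argument switches to the SSE stream.
  const streaming = await workflow.run({ topic: "release notes" }, "end-user-1", true);

  // camelCase filter options are mapped to the backend's snake_case params.
  const logs = await workflow.getLogs({
    status: "succeeded",
    createdByEndUserSessionId: "session-123",
    page: 1,
    limit: 20,
  });

  console.log(blocking.status, streaming.status, logs.status);
}

demo().catch(console.error);
```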
|
||||
|
|
@@ -0,0 +1,21 @@
|
|||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { WorkspaceClient } from "./workspace";
|
||||
import { createHttpClientWithSpies } from "../../tests/test-utils";
|
||||
|
||||
describe("WorkspaceClient", () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("gets models by type", async () => {
|
||||
const { client, request } = createHttpClientWithSpies();
|
||||
const workspace = new WorkspaceClient(client);
|
||||
|
||||
await workspace.getModelsByType("llm");
|
||||
|
||||
expect(request).toHaveBeenCalledWith({
|
||||
method: "GET",
|
||||
path: "/workspaces/current/models/model-types/llm",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@@ -0,0 +1,16 @@
|
|||
import { DifyClient } from "./base";
|
||||
import type { WorkspaceModelType, WorkspaceModelsResponse } from "../types/workspace";
|
||||
import type { DifyResponse } from "../types/common";
|
||||
import { ensureNonEmptyString } from "./validation";
|
||||
|
||||
export class WorkspaceClient extends DifyClient {
|
||||
async getModelsByType(
|
||||
modelType: WorkspaceModelType
|
||||
): Promise<DifyResponse<WorkspaceModelsResponse>> {
|
||||
ensureNonEmptyString(modelType, "modelType");
|
||||
return this.http.request({
|
||||
method: "GET",
|
||||
path: `/workspaces/current/models/model-types/${modelType}`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@@ -0,0 +1,37 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
DifyError,
|
||||
FileUploadError,
|
||||
NetworkError,
|
||||
RateLimitError,
|
||||
TimeoutError,
|
||||
ValidationError,
|
||||
} from "./dify-error";
|
||||
|
||||
describe("Dify errors", () => {
|
||||
it("sets base error fields", () => {
|
||||
const err = new DifyError("base", {
|
||||
statusCode: 400,
|
||||
responseBody: { message: "bad" },
|
||||
requestId: "req",
|
||||
retryAfter: 1,
|
||||
});
|
||||
expect(err.name).toBe("DifyError");
|
||||
expect(err.statusCode).toBe(400);
|
||||
expect(err.responseBody).toEqual({ message: "bad" });
|
||||
expect(err.requestId).toBe("req");
|
||||
expect(err.retryAfter).toBe(1);
|
||||
});
|
||||
|
||||
it("creates specific error types", () => {
|
||||
expect(new APIError("api").name).toBe("APIError");
|
||||
expect(new AuthenticationError("auth").name).toBe("AuthenticationError");
|
||||
expect(new RateLimitError("rate").name).toBe("RateLimitError");
|
||||
expect(new ValidationError("val").name).toBe("ValidationError");
|
||||
expect(new NetworkError("net").name).toBe("NetworkError");
|
||||
expect(new TimeoutError("timeout").name).toBe("TimeoutError");
|
||||
expect(new FileUploadError("upload").name).toBe("FileUploadError");
|
||||
});
|
||||
});
|
||||
|
|
@@ -0,0 +1,75 @@
|
|||
export type DifyErrorOptions = {
|
||||
statusCode?: number;
|
||||
responseBody?: unknown;
|
||||
requestId?: string;
|
||||
retryAfter?: number;
|
||||
cause?: unknown;
|
||||
};
|
||||
|
||||
export class DifyError extends Error {
|
||||
statusCode?: number;
|
||||
responseBody?: unknown;
|
||||
requestId?: string;
|
||||
retryAfter?: number;
|
||||
|
||||
constructor(message: string, options: DifyErrorOptions = {}) {
|
||||
super(message);
|
||||
this.name = "DifyError";
|
||||
this.statusCode = options.statusCode;
|
||||
this.responseBody = options.responseBody;
|
||||
this.requestId = options.requestId;
|
||||
this.retryAfter = options.retryAfter;
|
||||
if (options.cause) {
|
||||
(this as { cause?: unknown }).cause = options.cause;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class APIError extends DifyError {
|
||||
constructor(message: string, options: DifyErrorOptions = {}) {
|
||||
super(message, options);
|
||||
this.name = "APIError";
|
||||
}
|
||||
}
|
||||
|
||||
export class AuthenticationError extends APIError {
|
||||
constructor(message: string, options: DifyErrorOptions = {}) {
|
||||
super(message, options);
|
||||
this.name = "AuthenticationError";
|
||||
}
|
||||
}
|
||||
|
||||
export class RateLimitError extends APIError {
|
||||
constructor(message: string, options: DifyErrorOptions = {}) {
|
||||
super(message, options);
|
||||
this.name = "RateLimitError";
|
||||
}
|
||||
}
|
||||
|
||||
export class ValidationError extends APIError {
|
||||
constructor(message: string, options: DifyErrorOptions = {}) {
|
||||
super(message, options);
|
||||
this.name = "ValidationError";
|
||||
}
|
||||
}
|
||||
|
||||
export class NetworkError extends DifyError {
|
||||
constructor(message: string, options: DifyErrorOptions = {}) {
|
||||
super(message, options);
|
||||
this.name = "NetworkError";
|
||||
}
|
||||
}
|
||||
|
||||
export class TimeoutError extends DifyError {
|
||||
constructor(message: string, options: DifyErrorOptions = {}) {
|
||||
super(message, options);
|
||||
this.name = "TimeoutError";
|
||||
}
|
||||
}
|
||||
|
||||
export class FileUploadError extends DifyError {
|
||||
constructor(message: string, options: DifyErrorOptions = {}) {
|
||||
super(message, options);
|
||||
this.name = "FileUploadError";
|
||||
}
|
||||
}
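
Since the specific classes extend `APIError`/`DifyError` and set their own `name`, consumers can narrow with `instanceof`; check the most specific subclasses first. A small sketch of such a handler:

```typescript
import {
  APIError,
  DifyError,
  RateLimitError,
  ValidationError,
} from "./dify-error";

// Illustrative handler: turn a failed SDK call into a log-friendly string.
function describeFailure(error: unknown): string {
  if (error instanceof RateLimitError) {
    return `rate limited; retry after ${error.retryAfter ?? "?"}s`;
  }
  if (error instanceof ValidationError) {
    return `request rejected: ${error.message}`;
  }
  if (error instanceof APIError) {
    // Generic HTTP failures carry status, body, and request id when available.
    return `API error ${error.statusCode ?? "?"} (request ${error.requestId ?? "n/a"})`;
  }
  if (error instanceof DifyError) {
    // Network, timeout, and upload errors also extend DifyError.
    return `client error: ${error.message}`;
  }
  return "unexpected error";
}

console.log(describeFailure(new RateLimitError("slow down", { retryAfter: 2 })));
```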
|
||||
|
|
@@ -0,0 +1,304 @@
|
|||
import axios from "axios";
|
||||
import { Readable } from "node:stream";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
FileUploadError,
|
||||
NetworkError,
|
||||
RateLimitError,
|
||||
TimeoutError,
|
||||
ValidationError,
|
||||
} from "../errors/dify-error";
|
||||
import { HttpClient } from "./client";
|
||||
|
||||
describe("HttpClient", () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
it("builds requests with auth headers and JSON content type", async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
status: 200,
|
||||
data: { ok: true },
|
||||
headers: { "x-request-id": "req" },
|
||||
});
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
|
||||
const client = new HttpClient({ apiKey: "test" });
|
||||
const response = await client.request({
|
||||
method: "POST",
|
||||
path: "/chat-messages",
|
||||
data: { user: "u" },
|
||||
});
|
||||
|
||||
expect(response.requestId).toBe("req");
|
||||
const config = mockRequest.mock.calls[0][0];
|
||||
expect(config.headers.Authorization).toBe("Bearer test");
|
||||
expect(config.headers["Content-Type"]).toBe("application/json");
|
||||
expect(config.responseType).toBe("json");
|
||||
});
|
||||
|
||||
it("serializes array query params", async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
status: 200,
|
||||
data: "ok",
|
||||
headers: {},
|
||||
});
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
|
||||
const client = new HttpClient({ apiKey: "test" });
|
||||
await client.requestRaw({
|
||||
method: "GET",
|
||||
path: "/datasets",
|
||||
query: { tag_ids: ["a", "b"], limit: 2 },
|
||||
});
|
||||
|
||||
const config = mockRequest.mock.calls[0][0];
|
||||
const queryString = config.paramsSerializer.serialize({
|
||||
tag_ids: ["a", "b"],
|
||||
limit: 2,
|
||||
});
|
||||
expect(queryString).toBe("tag_ids=a&tag_ids=b&limit=2");
|
||||
});
|
||||
|
||||
it("returns SSE stream helpers", async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
status: 200,
|
||||
data: Readable.from(["data: {\"text\":\"hi\"}\n\n"]),
|
||||
headers: { "x-request-id": "req" },
|
||||
});
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
|
||||
const client = new HttpClient({ apiKey: "test" });
|
||||
const stream = await client.requestStream({
|
||||
method: "POST",
|
||||
path: "/chat-messages",
|
||||
data: { user: "u" },
|
||||
});
|
||||
|
||||
expect(stream.status).toBe(200);
|
||||
expect(stream.requestId).toBe("req");
|
||||
await expect(stream.toText()).resolves.toBe("hi");
|
||||
});
|
||||
|
||||
it("returns binary stream helpers", async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
status: 200,
|
||||
data: Readable.from(["chunk"]),
|
||||
headers: { "x-request-id": "req" },
|
||||
});
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
|
||||
const client = new HttpClient({ apiKey: "test" });
|
||||
const stream = await client.requestBinaryStream({
|
||||
method: "POST",
|
||||
path: "/text-to-audio",
|
||||
data: { user: "u", text: "hi" },
|
||||
});
|
||||
|
||||
expect(stream.status).toBe(200);
|
||||
expect(stream.requestId).toBe("req");
|
||||
});
|
||||
|
||||
it("respects form-data headers", async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
status: 200,
|
||||
data: "ok",
|
||||
headers: {},
|
||||
});
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
|
||||
const client = new HttpClient({ apiKey: "test" });
|
||||
const form = {
|
||||
append: () => {},
|
||||
getHeaders: () => ({ "content-type": "multipart/form-data; boundary=abc" }),
|
||||
};
|
||||
|
||||
await client.requestRaw({
|
||||
method: "POST",
|
||||
path: "/files/upload",
|
||||
data: form,
|
||||
});
|
||||
|
||||
const config = mockRequest.mock.calls[0][0];
|
||||
expect(config.headers["content-type"]).toBe(
|
||||
"multipart/form-data; boundary=abc"
|
||||
);
|
||||
expect(config.headers["Content-Type"]).toBeUndefined();
|
||||
});
|
||||
|
||||
it("maps 401 and 429 errors", async () => {
|
||||
const mockRequest = vi.fn();
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
const client = new HttpClient({ apiKey: "test", maxRetries: 0 });
|
||||
|
||||
mockRequest.mockRejectedValueOnce({
|
||||
isAxiosError: true,
|
||||
response: {
|
||||
status: 401,
|
||||
data: { message: "unauthorized" },
|
||||
headers: {},
|
||||
},
|
||||
});
|
||||
await expect(
|
||||
client.requestRaw({ method: "GET", path: "/meta" })
|
||||
).rejects.toBeInstanceOf(AuthenticationError);
|
||||
|
||||
mockRequest.mockRejectedValueOnce({
|
||||
isAxiosError: true,
|
||||
response: {
|
||||
status: 429,
|
||||
data: { message: "rate" },
|
||||
headers: { "retry-after": "2" },
|
||||
},
|
||||
});
|
||||
const error = await client
|
||||
.requestRaw({ method: "GET", path: "/meta" })
|
||||
.catch((err) => err);
|
||||
expect(error).toBeInstanceOf(RateLimitError);
|
||||
expect(error.retryAfter).toBe(2);
|
||||
});
|
||||
|
||||
it("maps validation and upload errors", async () => {
|
||||
const mockRequest = vi.fn();
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
const client = new HttpClient({ apiKey: "test", maxRetries: 0 });
|
||||
|
||||
mockRequest.mockRejectedValueOnce({
|
||||
isAxiosError: true,
|
||||
response: {
|
||||
status: 422,
|
||||
data: { message: "invalid" },
|
||||
headers: {},
|
||||
},
|
||||
});
|
||||
await expect(
|
||||
client.requestRaw({ method: "POST", path: "/chat-messages", data: { user: "u" } })
|
||||
).rejects.toBeInstanceOf(ValidationError);
|
||||
|
||||
mockRequest.mockRejectedValueOnce({
|
||||
isAxiosError: true,
|
||||
config: { url: "/files/upload" },
|
||||
response: {
|
||||
status: 400,
|
||||
data: { message: "bad upload" },
|
||||
headers: {},
|
||||
},
|
||||
});
|
||||
await expect(
|
||||
client.requestRaw({ method: "POST", path: "/files/upload", data: { user: "u" } })
|
||||
).rejects.toBeInstanceOf(FileUploadError);
|
||||
});
|
||||
|
||||
it("maps timeout and network errors", async () => {
|
||||
const mockRequest = vi.fn();
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
const client = new HttpClient({ apiKey: "test", maxRetries: 0 });
|
||||
|
||||
mockRequest.mockRejectedValueOnce({
|
||||
isAxiosError: true,
|
||||
code: "ECONNABORTED",
|
||||
message: "timeout",
|
||||
});
|
||||
await expect(
|
||||
client.requestRaw({ method: "GET", path: "/meta" })
|
||||
).rejects.toBeInstanceOf(TimeoutError);
|
||||
|
||||
mockRequest.mockRejectedValueOnce({
|
||||
isAxiosError: true,
|
||||
message: "network",
|
||||
});
|
||||
await expect(
|
||||
client.requestRaw({ method: "GET", path: "/meta" })
|
||||
).rejects.toBeInstanceOf(NetworkError);
|
||||
});
|
||||
|
||||
it("retries on timeout errors", async () => {
|
||||
const mockRequest = vi.fn();
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
const client = new HttpClient({ apiKey: "test", maxRetries: 1, retryDelay: 0 });
|
||||
|
||||
mockRequest
|
||||
.mockRejectedValueOnce({
|
||||
isAxiosError: true,
|
||||
code: "ECONNABORTED",
|
||||
message: "timeout",
|
||||
})
|
||||
.mockResolvedValueOnce({ status: 200, data: "ok", headers: {} });
|
||||
|
||||
await client.requestRaw({ method: "GET", path: "/meta" });
|
||||
expect(mockRequest).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("validates query parameters before request", async () => {
|
||||
const mockRequest = vi.fn();
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
const client = new HttpClient({ apiKey: "test" });
|
||||
|
||||
await expect(
|
||||
client.requestRaw({ method: "GET", path: "/meta", query: { user: 1 } })
|
||||
).rejects.toBeInstanceOf(ValidationError);
|
||||
expect(mockRequest).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("returns APIError for other http failures", async () => {
|
||||
const mockRequest = vi.fn();
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
const client = new HttpClient({ apiKey: "test", maxRetries: 0 });
|
||||
|
||||
mockRequest.mockRejectedValueOnce({
|
||||
isAxiosError: true,
|
||||
response: { status: 500, data: { message: "server" }, headers: {} },
|
||||
});
|
||||
|
||||
await expect(
|
||||
client.requestRaw({ method: "GET", path: "/meta" })
|
||||
).rejects.toBeInstanceOf(APIError);
|
||||
});
|
||||
|
||||
it("logs requests and responses when enableLogging is true", async () => {
|
||||
const mockRequest = vi.fn().mockResolvedValue({
|
||||
status: 200,
|
||||
data: { ok: true },
|
||||
headers: {},
|
||||
});
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
const consoleInfo = vi.spyOn(console, "info").mockImplementation(() => {});
|
||||
|
||||
const client = new HttpClient({ apiKey: "test", enableLogging: true });
|
||||
await client.requestRaw({ method: "GET", path: "/meta" });
|
||||
|
||||
expect(consoleInfo).toHaveBeenCalledWith(
|
||||
expect.stringContaining("dify-client-node response 200 GET")
|
||||
);
|
||||
consoleInfo.mockRestore();
|
||||
});
|
||||
|
||||
it("logs retry attempts when enableLogging is true", async () => {
|
||||
const mockRequest = vi.fn();
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
const consoleInfo = vi.spyOn(console, "info").mockImplementation(() => {});
|
||||
|
||||
const client = new HttpClient({
|
||||
apiKey: "test",
|
||||
maxRetries: 1,
|
||||
retryDelay: 0,
|
||||
enableLogging: true,
|
||||
});
|
||||
|
||||
mockRequest
|
||||
.mockRejectedValueOnce({
|
||||
isAxiosError: true,
|
||||
code: "ECONNABORTED",
|
||||
message: "timeout",
|
||||
})
|
||||
.mockResolvedValueOnce({ status: 200, data: "ok", headers: {} });
|
||||
|
||||
await client.requestRaw({ method: "GET", path: "/meta" });
|
||||
|
||||
expect(consoleInfo).toHaveBeenCalledWith(
|
||||
expect.stringContaining("dify-client-node retry")
|
||||
);
|
||||
consoleInfo.mockRestore();
|
||||
});
|
||||
});
|
||||
|
|
@@ -0,0 +1,368 @@
|
|||
import axios from "axios";
|
||||
import type {
|
||||
AxiosError,
|
||||
AxiosInstance,
|
||||
AxiosRequestConfig,
|
||||
AxiosResponse,
|
||||
} from "axios";
|
||||
import type { Readable } from "node:stream";
|
||||
import {
|
||||
DEFAULT_BASE_URL,
|
||||
DEFAULT_MAX_RETRIES,
|
||||
DEFAULT_RETRY_DELAY_SECONDS,
|
||||
DEFAULT_TIMEOUT_SECONDS,
|
||||
} from "../types/common";
|
||||
import type {
|
||||
DifyClientConfig,
|
||||
DifyResponse,
|
||||
Headers,
|
||||
QueryParams,
|
||||
RequestMethod,
|
||||
} from "../types/common";
|
||||
import type { DifyError } from "../errors/dify-error";
|
||||
import {
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
FileUploadError,
|
||||
NetworkError,
|
||||
RateLimitError,
|
||||
TimeoutError,
|
||||
ValidationError,
|
||||
} from "../errors/dify-error";
|
||||
import { getFormDataHeaders, isFormData } from "./form-data";
|
||||
import { createBinaryStream, createSseStream } from "./sse";
|
||||
import { getRetryDelayMs, shouldRetry, sleep } from "./retry";
|
||||
import { validateParams } from "../client/validation";
|
||||
|
||||
const DEFAULT_USER_AGENT = "dify-client-node";
|
||||
|
||||
export type RequestOptions = {
|
||||
method: RequestMethod;
|
||||
path: string;
|
||||
query?: QueryParams;
|
||||
data?: unknown;
|
||||
headers?: Headers;
|
||||
responseType?: AxiosRequestConfig["responseType"];
|
||||
};
|
||||
|
||||
export type HttpClientSettings = Required<
|
||||
Omit<DifyClientConfig, "apiKey">
|
||||
> & {
|
||||
apiKey: string;
|
||||
};
|
||||
|
||||
const normalizeSettings = (config: DifyClientConfig): HttpClientSettings => ({
|
||||
apiKey: config.apiKey,
|
||||
baseUrl: config.baseUrl ?? DEFAULT_BASE_URL,
|
||||
timeout: config.timeout ?? DEFAULT_TIMEOUT_SECONDS,
|
||||
maxRetries: config.maxRetries ?? DEFAULT_MAX_RETRIES,
|
||||
retryDelay: config.retryDelay ?? DEFAULT_RETRY_DELAY_SECONDS,
|
||||
enableLogging: config.enableLogging ?? false,
|
||||
});
|
||||
|
||||
const normalizeHeaders = (headers: AxiosResponse["headers"]): Headers => {
|
||||
const result: Headers = {};
|
||||
if (!headers) {
|
||||
return result;
|
||||
}
|
||||
Object.entries(headers).forEach(([key, value]) => {
|
||||
if (Array.isArray(value)) {
|
||||
result[key.toLowerCase()] = value.join(", ");
|
||||
} else if (typeof value === "string") {
|
||||
result[key.toLowerCase()] = value;
|
||||
} else if (typeof value === "number") {
|
||||
result[key.toLowerCase()] = value.toString();
|
||||
}
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
const resolveRequestId = (headers: Headers): string | undefined =>
|
||||
headers["x-request-id"] ?? headers["x-requestid"];
|
||||
|
||||
const buildRequestUrl = (baseUrl: string, path: string): string => {
|
||||
const trimmed = baseUrl.replace(/\/+$/, "");
|
||||
return `${trimmed}${path}`;
|
||||
};
|
||||
|
||||
const buildQueryString = (params?: QueryParams): string => {
|
||||
if (!params) {
|
||||
return "";
|
||||
}
|
||||
const searchParams = new URLSearchParams();
|
||||
Object.entries(params).forEach(([key, value]) => {
|
||||
if (value === undefined || value === null) {
|
||||
return;
|
||||
}
|
||||
if (Array.isArray(value)) {
|
||||
value.forEach((item) => {
|
||||
searchParams.append(key, String(item));
|
||||
});
|
||||
return;
|
||||
}
|
||||
searchParams.append(key, String(value));
|
||||
});
|
||||
return searchParams.toString();
|
||||
};
|
||||
|
||||
const parseRetryAfterSeconds = (headerValue?: string): number | undefined => {
|
||||
if (!headerValue) {
|
||||
return undefined;
|
||||
}
|
||||
const asNumber = Number.parseInt(headerValue, 10);
|
||||
if (!Number.isNaN(asNumber)) {
|
||||
return asNumber;
|
||||
}
|
||||
const asDate = Date.parse(headerValue);
|
||||
if (!Number.isNaN(asDate)) {
|
||||
const diff = asDate - Date.now();
|
||||
return diff > 0 ? Math.ceil(diff / 1000) : 0;
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
|
||||
const isReadableStream = (value: unknown): value is Readable => {
|
||||
if (!value || typeof value !== "object") {
|
||||
return false;
|
||||
}
|
||||
return typeof (value as { pipe?: unknown }).pipe === "function";
|
||||
};
|
||||
|
||||
const isUploadLikeRequest = (config?: AxiosRequestConfig): boolean => {
|
||||
const url = (config?.url ?? "").toLowerCase();
|
||||
if (!url) {
|
||||
return false;
|
||||
}
|
||||
return (
|
||||
url.includes("upload") ||
|
||||
url.includes("/files/") ||
|
||||
url.includes("audio-to-text") ||
|
||||
url.includes("create_by_file") ||
|
||||
url.includes("update_by_file")
|
||||
);
|
||||
};
|
||||
|
||||
const resolveErrorMessage = (status: number, responseBody: unknown): string => {
|
||||
if (typeof responseBody === "string" && responseBody.trim().length > 0) {
|
||||
return responseBody;
|
||||
}
|
||||
if (
|
||||
responseBody &&
|
||||
typeof responseBody === "object" &&
|
||||
"message" in responseBody
|
||||
) {
|
||||
const message = (responseBody as Record<string, unknown>).message;
|
||||
if (typeof message === "string" && message.trim().length > 0) {
|
||||
return message;
|
||||
}
|
||||
}
|
||||
return `Request failed with status code ${status}`;
|
||||
};
|
||||
|
||||
const mapAxiosError = (error: unknown): DifyError => {
|
||||
if (axios.isAxiosError(error)) {
|
||||
const axiosError = error as AxiosError;
|
||||
if (axiosError.response) {
|
||||
const status = axiosError.response.status;
|
||||
const headers = normalizeHeaders(axiosError.response.headers);
|
||||
const requestId = resolveRequestId(headers);
|
||||
const responseBody = axiosError.response.data;
|
||||
const message = resolveErrorMessage(status, responseBody);
|
||||
|
||||
if (status === 401) {
|
||||
return new AuthenticationError(message, {
|
||||
statusCode: status,
|
||||
responseBody,
|
||||
requestId,
|
||||
});
|
||||
}
|
||||
if (status === 429) {
|
||||
const retryAfter = parseRetryAfterSeconds(headers["retry-after"]);
|
||||
return new RateLimitError(message, {
|
||||
statusCode: status,
|
||||
responseBody,
|
||||
requestId,
|
||||
retryAfter,
|
||||
});
|
||||
}
|
||||
if (status === 422) {
|
||||
return new ValidationError(message, {
|
||||
statusCode: status,
|
||||
responseBody,
|
||||
requestId,
|
||||
});
|
||||
}
|
||||
if (status === 400) {
|
||||
if (isUploadLikeRequest(axiosError.config)) {
|
||||
return new FileUploadError(message, {
|
||||
statusCode: status,
|
||||
responseBody,
|
||||
requestId,
|
||||
});
|
||||
}
|
||||
}
|
||||
return new APIError(message, {
|
||||
statusCode: status,
|
||||
responseBody,
|
||||
requestId,
|
||||
});
|
||||
}
|
||||
if (axiosError.code === "ECONNABORTED") {
|
||||
return new TimeoutError("Request timed out", { cause: axiosError });
|
||||
}
|
||||
return new NetworkError(axiosError.message, { cause: axiosError });
|
||||
}
|
||||
if (error instanceof Error) {
|
||||
return new NetworkError(error.message, { cause: error });
|
||||
}
|
||||
return new NetworkError("Unexpected network error", { cause: error });
|
||||
};
|
||||
|
||||
export class HttpClient {
|
||||
private axios: AxiosInstance;
|
||||
private settings: HttpClientSettings;
|
||||
|
||||
constructor(config: DifyClientConfig) {
|
||||
this.settings = normalizeSettings(config);
|
||||
this.axios = axios.create({
|
||||
baseURL: this.settings.baseUrl,
|
||||
timeout: this.settings.timeout * 1000,
|
||||
});
|
||||
}
|
||||
|
||||
updateApiKey(apiKey: string): void {
|
||||
this.settings.apiKey = apiKey;
|
||||
}
|
||||
|
||||
getSettings(): HttpClientSettings {
|
||||
return { ...this.settings };
|
||||
}
|
||||
|
||||
async request<T>(options: RequestOptions): Promise<DifyResponse<T>> {
|
||||
const response = await this.requestRaw(options);
|
||||
const headers = normalizeHeaders(response.headers);
|
||||
return {
|
||||
data: response.data as T,
|
||||
status: response.status,
|
||||
headers,
|
||||
requestId: resolveRequestId(headers),
|
||||
};
|
||||
}
|
||||
|
||||
async requestStream<T>(options: RequestOptions) {
|
||||
const response = await this.requestRaw({
|
||||
...options,
|
||||
responseType: "stream",
|
||||
});
|
||||
const headers = normalizeHeaders(response.headers);
|
||||
return createSseStream<T>(response.data as Readable, {
|
||||
status: response.status,
|
||||
headers,
|
||||
requestId: resolveRequestId(headers),
|
||||
});
|
||||
}
|
||||
|
||||
async requestBinaryStream(options: RequestOptions) {
|
||||
const response = await this.requestRaw({
|
||||
...options,
|
||||
responseType: "stream",
|
||||
});
|
||||
const headers = normalizeHeaders(response.headers);
|
||||
return createBinaryStream(response.data as Readable, {
|
||||
status: response.status,
|
||||
headers,
|
||||
requestId: resolveRequestId(headers),
|
||||
});
|
||||
}
|
||||
|
||||
async requestRaw(options: RequestOptions): Promise<AxiosResponse> {
|
||||
const { method, path, query, data, headers, responseType } = options;
|
||||
const { apiKey, enableLogging, maxRetries, retryDelay, timeout } =
|
||||
this.settings;
|
||||
|
||||
if (query) {
|
||||
validateParams(query as Record<string, unknown>);
|
||||
}
|
||||
if (
|
||||
data &&
|
||||
typeof data === "object" &&
|
||||
!Array.isArray(data) &&
|
||||
!isFormData(data) &&
|
||||
!isReadableStream(data)
|
||||
) {
|
||||
validateParams(data as Record<string, unknown>);
|
||||
}
|
||||
|
||||
const requestHeaders: Headers = {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...headers,
|
||||
};
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
!!process.versions?.node &&
|
||||
!requestHeaders["User-Agent"] &&
|
||||
!requestHeaders["user-agent"]
|
||||
) {
|
||||
requestHeaders["User-Agent"] = DEFAULT_USER_AGENT;
|
||||
}
|
||||
|
||||
if (isFormData(data)) {
|
||||
Object.assign(requestHeaders, getFormDataHeaders(data));
|
||||
} else if (data && method !== "GET") {
|
||||
requestHeaders["Content-Type"] = "application/json";
|
||||
}
|
||||
|
||||
const url = buildRequestUrl(this.settings.baseUrl, path);
|
||||
|
||||
if (enableLogging) {
|
||||
console.info(`dify-client-node request ${method} ${url}`);
|
||||
}
|
||||
|
||||
const axiosConfig: AxiosRequestConfig = {
|
||||
method,
|
||||
url: path,
|
||||
params: query,
|
||||
paramsSerializer: {
|
||||
serialize: (params) => buildQueryString(params as QueryParams),
|
||||
},
|
||||
headers: requestHeaders,
|
||||
responseType: responseType ?? "json",
|
||||
timeout: timeout * 1000,
|
||||
};
|
||||
|
||||
if (method !== "GET" && data !== undefined) {
|
||||
axiosConfig.data = data;
|
||||
}
|
||||
|
||||
let attempt = 0;
|
||||
// `attempt` is a zero-based retry counter
|
||||
// Total attempts = 1 (initial) + maxRetries
|
||||
// e.g., maxRetries=3 means: attempt 0 (initial), then retries at 1, 2, 3
|
||||
while (true) {
|
||||
try {
|
||||
const response = await this.axios.request(axiosConfig);
|
||||
if (enableLogging) {
|
||||
console.info(
|
||||
`dify-client-node response ${response.status} ${method} ${url}`
|
||||
);
|
||||
}
|
||||
return response;
|
||||
} catch (error) {
|
||||
const mapped = mapAxiosError(error);
|
||||
if (!shouldRetry(mapped, attempt, maxRetries)) {
|
||||
throw mapped;
|
||||
}
|
||||
const retryAfterSeconds =
|
||||
mapped instanceof RateLimitError ? mapped.retryAfter : undefined;
|
||||
const delay = getRetryDelayMs(attempt + 1, retryDelay, retryAfterSeconds);
|
||||
if (enableLogging) {
|
||||
console.info(
|
||||
`dify-client-node retry ${attempt + 1} in ${delay}ms for ${method} ${url}`
|
||||
);
|
||||
}
|
||||
attempt += 1;
|
||||
await sleep(delay);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
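
In combination, the config controls timeouts, retries, and logging, and the same client exposes JSON, SSE, and binary request shapes. A minimal end-to-end sketch (the API key is a placeholder, and calls are wrapped in an async function for runnability):

```typescript
import { HttpClient } from "./client";

async function demo(): Promise<void> {
  const client = new HttpClient({
    apiKey: "app-xxxxxxxx", // placeholder key
    timeout: 30,            // seconds; multiplied by 1000 for axios
    maxRetries: 2,          // total attempts = 1 initial + 2 retries
    retryDelay: 1,          // base backoff in seconds
    enableLogging: true,
  });

  // Plain JSON request with a typed payload.
  const info = await client.request<{ name?: string }>({
    method: "GET",
    path: "/info",
  });
  console.log(info.status, info.requestId, info.data.name);

  // SSE request: iterate events or collapse the stream into text.
  const stream = await client.requestStream<{ answer?: string }>({
    method: "POST",
    path: "/chat-messages",
    data: { query: "hi", user: "end-user-1", response_mode: "streaming" },
  });
  console.log(await stream.toText());
}

demo().catch(console.error);
```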
|
||||
|
|
@@ -0,0 +1,23 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { getFormDataHeaders, isFormData } from "./form-data";
|
||||
|
||||
describe("form-data helpers", () => {
|
||||
it("detects form-data like objects", () => {
|
||||
const formLike = {
|
||||
append: () => {},
|
||||
getHeaders: () => ({ "content-type": "multipart/form-data" }),
|
||||
};
|
||||
expect(isFormData(formLike)).toBe(true);
|
||||
expect(isFormData({})).toBe(false);
|
||||
});
|
||||
|
||||
it("returns headers from form-data", () => {
|
||||
const formLike = {
|
||||
append: () => {},
|
||||
getHeaders: () => ({ "content-type": "multipart/form-data" }),
|
||||
};
|
||||
expect(getFormDataHeaders(formLike)).toEqual({
|
||||
"content-type": "multipart/form-data",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@@ -0,0 +1,31 @@
|
|||
import type { Headers } from "../types/common";
|
||||
|
||||
export type FormDataLike = {
|
||||
append: (...args: unknown[]) => void;
|
||||
getHeaders?: () => Headers;
|
||||
constructor?: { name?: string };
|
||||
};
|
||||
|
||||
export const isFormData = (value: unknown): value is FormDataLike => {
|
||||
if (!value || typeof value !== "object") {
|
||||
return false;
|
||||
}
|
||||
if (typeof FormData !== "undefined" && value instanceof FormData) {
|
||||
return true;
|
||||
}
|
||||
const candidate = value as FormDataLike;
|
||||
if (typeof candidate.append !== "function") {
|
||||
return false;
|
||||
}
|
||||
if (typeof candidate.getHeaders === "function") {
|
||||
return true;
|
||||
}
|
||||
return candidate.constructor?.name === "FormData";
|
||||
};
|
||||
|
||||
export const getFormDataHeaders = (form: FormDataLike): Headers => {
|
||||
if (typeof form.getHeaders === "function") {
|
||||
return form.getHeaders();
|
||||
}
|
||||
return {};
|
||||
};
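
The duck-typed check accepts both the WHATWG `FormData` and Node's `form-data` package, since either exposes `append` (and the latter `getHeaders`). A sketch of preparing an upload body, assuming the `form-data` package is available (it is not a confirmed dependency of this SDK):

```typescript
import FormData from "form-data"; // assumed, not a confirmed SDK dependency
import { createReadStream } from "node:fs";
import { getFormDataHeaders, isFormData } from "./form-data";

const form = new FormData();
form.append("user", "end-user-1"); // placeholder end-user id
form.append("file", createReadStream("./sample.mp3"));

if (isFormData(form)) {
  // The multipart boundary header from the form replaces the default
  // application/json content type on the outgoing request.
  console.log(getFormDataHeaders(form));
}
```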
|
||||
|
|
@@ -0,0 +1,38 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import { getRetryDelayMs, shouldRetry } from "./retry";
|
||||
import { NetworkError, RateLimitError, TimeoutError } from "../errors/dify-error";
|
||||
|
||||
const withMockedRandom = (value: number, fn: () => void) => {
|
||||
const original = Math.random;
|
||||
Math.random = () => value;
|
||||
try {
|
||||
fn();
|
||||
} finally {
|
||||
Math.random = original;
|
||||
}
|
||||
};
|
||||
|
||||
describe("retry helpers", () => {
|
||||
it("getRetryDelayMs honors retry-after header", () => {
|
||||
expect(getRetryDelayMs(1, 1, 3)).toBe(3000);
|
||||
});
|
||||
|
||||
it("getRetryDelayMs uses exponential backoff with jitter", () => {
|
||||
withMockedRandom(0, () => {
|
||||
expect(getRetryDelayMs(1, 1)).toBe(1000);
|
||||
expect(getRetryDelayMs(2, 1)).toBe(2000);
|
||||
expect(getRetryDelayMs(3, 1)).toBe(4000);
|
||||
});
|
||||
});
|
||||
|
||||
it("shouldRetry respects max retries", () => {
|
||||
expect(shouldRetry(new TimeoutError("timeout"), 3, 3)).toBe(false);
|
||||
});
|
||||
|
||||
it("shouldRetry retries on network, timeout, and rate limit", () => {
|
||||
expect(shouldRetry(new TimeoutError("timeout"), 0, 3)).toBe(true);
|
||||
expect(shouldRetry(new NetworkError("network"), 0, 3)).toBe(true);
|
||||
expect(shouldRetry(new RateLimitError("limit"), 0, 3)).toBe(true);
|
||||
expect(shouldRetry(new Error("other"), 0, 3)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
|
@@ -0,0 +1,40 @@
|
|||
import { RateLimitError, NetworkError, TimeoutError } from "../errors/dify-error";
|
||||
|
||||
export const sleep = (ms: number): Promise<void> =>
|
||||
new Promise((resolve) => {
|
||||
setTimeout(resolve, ms);
|
||||
});
|
||||
|
||||
export const getRetryDelayMs = (
|
||||
attempt: number,
|
||||
retryDelaySeconds: number,
|
||||
retryAfterSeconds?: number
|
||||
): number => {
|
||||
if (retryAfterSeconds && retryAfterSeconds > 0) {
|
||||
return retryAfterSeconds * 1000;
|
||||
}
|
||||
const base = retryDelaySeconds * 1000;
|
||||
const exponential = base * Math.pow(2, Math.max(0, attempt - 1));
|
||||
const jitter = Math.random() * base;
|
||||
return exponential + jitter;
|
||||
};
|
||||
|
||||
export const shouldRetry = (
|
||||
error: unknown,
|
||||
attempt: number,
|
||||
maxRetries: number
|
||||
): boolean => {
|
||||
if (attempt >= maxRetries) {
|
||||
return false;
|
||||
}
|
||||
if (error instanceof TimeoutError) {
|
||||
return true;
|
||||
}
|
||||
if (error instanceof NetworkError) {
|
||||
return true;
|
||||
}
|
||||
if (error instanceof RateLimitError) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
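
With the default one-second `retryDelay`, the delay floor doubles per attempt (1s, 2s, 4s, ...) plus up to one base interval of jitter, and an explicit `Retry-After` value wins outright. A quick sketch of the arithmetic:

```typescript
import { getRetryDelayMs } from "./retry";

// Each delay is retryDelay * 2^(attempt - 1) seconds plus up to one base
// interval of random jitter, so with retryDelay = 1 the floors are 1s, 2s, 4s.
for (const attempt of [1, 2, 3]) {
  console.log(`retry ${attempt}: ~${Math.round(getRetryDelayMs(attempt, 1))} ms`);
}

// A Retry-After of 5 seconds overrides the exponential backoff entirely.
console.log(getRetryDelayMs(1, 1, 5)); // 5000
```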
|
||||
|
|
@@ -0,0 +1,76 @@
|
|||
import { Readable } from "node:stream";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { createBinaryStream, createSseStream, parseSseStream } from "./sse";
|
||||
|
||||
describe("sse parsing", () => {
|
||||
it("parses event and data lines", async () => {
|
||||
const stream = Readable.from([
|
||||
"event: message\n",
|
||||
"data: {\"answer\":\"hi\"}\n",
|
||||
"\n",
|
||||
]);
|
||||
const events = [];
|
||||
for await (const event of parseSseStream(stream)) {
|
||||
events.push(event);
|
||||
}
|
||||
expect(events).toHaveLength(1);
|
||||
expect(events[0].event).toBe("message");
|
||||
expect(events[0].data).toEqual({ answer: "hi" });
|
||||
});
|
||||
|
||||
it("handles multi-line data payloads", async () => {
|
||||
const stream = Readable.from(["data: line1\n", "data: line2\n", "\n"]);
|
||||
const events = [];
|
||||
for await (const event of parseSseStream(stream)) {
|
||||
events.push(event);
|
||||
}
|
||||
expect(events[0].raw).toBe("line1\nline2");
|
||||
expect(events[0].data).toBe("line1\nline2");
|
||||
});
|
||||
|
||||
it("createSseStream exposes toText", async () => {
|
||||
const stream = Readable.from([
|
||||
"data: {\"answer\":\"hello\"}\n\n",
|
||||
"data: {\"delta\":\" world\"}\n\n",
|
||||
]);
|
||||
const sseStream = createSseStream(stream, {
|
||||
status: 200,
|
||||
headers: {},
|
||||
requestId: "req",
|
||||
});
|
||||
const text = await sseStream.toText();
|
||||
expect(text).toBe("hello world");
|
||||
});
|
||||
|
||||
it("toText extracts text from string data", async () => {
|
||||
const stream = Readable.from(["data: plain text\n\n"]);
|
||||
const sseStream = createSseStream(stream, { status: 200, headers: {} });
|
||||
const text = await sseStream.toText();
|
||||
expect(text).toBe("plain text");
|
||||
});
|
||||
|
||||
it("toText extracts text field from object", async () => {
|
||||
const stream = Readable.from(['data: {"text":"hello"}\n\n']);
|
||||
const sseStream = createSseStream(stream, { status: 200, headers: {} });
|
||||
const text = await sseStream.toText();
|
||||
expect(text).toBe("hello");
|
||||
});
|
||||
|
||||
it("toText returns empty for invalid data", async () => {
|
||||
const stream = Readable.from(["data: null\n\n", "data: 123\n\n"]);
|
||||
const sseStream = createSseStream(stream, { status: 200, headers: {} });
|
||||
const text = await sseStream.toText();
|
||||
expect(text).toBe("");
|
||||
});
|
||||
|
||||
it("createBinaryStream exposes metadata", () => {
|
||||
const stream = Readable.from(["chunk"]);
|
||||
const binary = createBinaryStream(stream, {
|
||||
status: 200,
|
||||
headers: { "content-type": "audio/mpeg" },
|
||||
requestId: "req",
|
||||
});
|
||||
expect(binary.status).toBe(200);
|
||||
expect(binary.headers["content-type"]).toBe("audio/mpeg");
|
||||
});
|
||||
});
|
||||
|
|
@@ -0,0 +1,133 @@
|
|||
import type { Readable } from "node:stream";
|
||||
import { StringDecoder } from "node:string_decoder";
|
||||
import type { BinaryStream, DifyStream, Headers, StreamEvent } from "../types/common";
|
||||
|
||||
const readLines = async function* (stream: Readable): AsyncIterable<string> {
|
||||
const decoder = new StringDecoder("utf8");
|
||||
let buffered = "";
|
||||
for await (const chunk of stream) {
|
||||
buffered += decoder.write(chunk as Buffer);
|
||||
let index = buffered.indexOf("\n");
|
||||
while (index >= 0) {
|
||||
let line = buffered.slice(0, index);
|
||||
buffered = buffered.slice(index + 1);
|
||||
if (line.endsWith("\r")) {
|
||||
line = line.slice(0, -1);
|
||||
}
|
||||
yield line;
|
||||
index = buffered.indexOf("\n");
|
||||
}
|
||||
}
|
||||
buffered += decoder.end();
|
||||
if (buffered) {
|
||||
yield buffered;
|
||||
}
|
||||
};
|
||||
|
||||
const parseMaybeJson = (value: string): unknown => {
|
||||
if (!value) {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return JSON.parse(value);
|
||||
} catch {
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
export const parseSseStream = async function* <T>(
|
||||
stream: Readable
|
||||
): AsyncIterable<StreamEvent<T>> {
|
||||
let eventName: string | undefined;
|
||||
const dataLines: string[] = [];
|
||||
|
||||
const emitEvent = function* (): Iterable<StreamEvent<T>> {
|
||||
if (!eventName && dataLines.length === 0) {
|
||||
return;
|
||||
}
|
||||
const raw = dataLines.join("\n");
|
||||
const parsed = parseMaybeJson(raw) as T | string | null;
|
||||
yield {
|
||||
event: eventName,
|
||||
data: parsed,
|
||||
raw,
|
||||
};
|
||||
eventName = undefined;
|
||||
dataLines.length = 0;
|
||||
};
|
||||
|
||||
for await (const line of readLines(stream)) {
|
||||
if (!line) {
|
||||
yield* emitEvent();
|
||||
continue;
|
||||
}
|
||||
if (line.startsWith(":")) {
|
||||
continue;
|
||||
}
|
||||
if (line.startsWith("event:")) {
|
||||
eventName = line.slice("event:".length).trim();
|
||||
continue;
|
||||
}
|
||||
if (line.startsWith("data:")) {
|
||||
dataLines.push(line.slice("data:".length).trimStart());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
yield* emitEvent();
|
||||
};
|
||||
|
||||
const extractTextFromEvent = (data: unknown): string => {
|
||||
if (typeof data === "string") {
|
||||
return data;
|
||||
}
|
||||
if (!data || typeof data !== "object") {
|
||||
return "";
|
||||
}
|
||||
const record = data as Record<string, unknown>;
|
||||
if (typeof record.answer === "string") {
|
||||
return record.answer;
|
||||
}
|
||||
if (typeof record.text === "string") {
|
||||
return record.text;
|
||||
}
|
||||
if (typeof record.delta === "string") {
|
||||
return record.delta;
|
||||
}
|
||||
return "";
|
||||
};
|
||||
|
||||
export const createSseStream = <T>(
|
||||
stream: Readable,
|
||||
meta: { status: number; headers: Headers; requestId?: string }
|
||||
): DifyStream<T> => {
|
||||
const iterator = parseSseStream<T>(stream)[Symbol.asyncIterator]();
|
||||
const iterable = {
|
||||
[Symbol.asyncIterator]: () => iterator,
|
||||
data: stream,
|
||||
status: meta.status,
|
||||
headers: meta.headers,
|
||||
requestId: meta.requestId,
|
||||
toReadable: () => stream,
|
||||
toText: async () => {
|
||||
let text = "";
|
||||
for await (const event of iterable) {
|
||||
text += extractTextFromEvent(event.data);
|
||||
}
|
||||
return text;
|
||||
},
|
||||
} satisfies DifyStream<T>;
|
||||
|
||||
return iterable;
|
||||
};
|
||||
|
||||
export const createBinaryStream = (
|
||||
stream: Readable,
|
||||
meta: { status: number; headers: Headers; requestId?: string }
|
||||
): BinaryStream => ({
|
||||
data: stream,
|
||||
status: meta.status,
|
||||
headers: meta.headers,
|
||||
requestId: meta.requestId,
|
||||
toReadable: () => stream,
|
||||
});
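
The parser emits one event per blank-line-delimited SSE block and JSON-decodes the `data:` payload when it can. A small sketch feeding it a canned stream:

```typescript
import { Readable } from "node:stream";
import { parseSseStream } from "./sse";

async function demo(): Promise<void> {
  const raw = Readable.from([
    'event: message\ndata: {"answer":"Hello"}\n\n',
    'data: {"answer":" world"}\n\n',
  ]);

  for await (const event of parseSseStream<{ answer?: string }>(raw)) {
    // `data` is the parsed object when the payload is valid JSON, otherwise
    // the raw string; `raw` always holds the joined data lines.
    console.log(event.event, event.data);
  }
}

demo().catch(console.error);
```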
|
||||
|
|
@@ -0,0 +1,227 @@
|
|||
import { afterAll, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { ChatClient, DifyClient, WorkflowClient, BASE_URL, routes } from "./index";
|
||||
import axios from "axios";
|
||||
|
||||
const mockRequest = vi.fn();
|
||||
|
||||
const setupAxiosMock = () => {
|
||||
vi.spyOn(axios, "create").mockReturnValue({ request: mockRequest });
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
mockRequest.mockReset();
|
||||
setupAxiosMock();
|
||||
});
|
||||
|
||||
describe("Client", () => {
|
||||
it("should create a client", () => {
|
||||
new DifyClient("test");
|
||||
|
||||
expect(axios.create).toHaveBeenCalledWith({
|
||||
baseURL: BASE_URL,
|
||||
timeout: 60000,
|
||||
});
|
||||
});
|
||||
|
||||
it("should update the api key", () => {
|
||||
const difyClient = new DifyClient("test");
|
||||
difyClient.updateApiKey("test2");
|
||||
|
||||
expect(difyClient.getHttpClient().getSettings().apiKey).toBe("test2");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Send Requests", () => {
|
||||
it("should make a successful request to the application parameter", async () => {
|
||||
const difyClient = new DifyClient("test");
|
||||
const method = "GET";
|
||||
const endpoint = routes.application.url();
|
||||
mockRequest.mockResolvedValue({
|
||||
status: 200,
|
||||
data: "response",
|
||||
headers: {},
|
||||
});
|
||||
|
||||
await difyClient.sendRequest(method, endpoint);
|
||||
|
||||
const requestConfig = mockRequest.mock.calls[0][0];
|
||||
expect(requestConfig).toMatchObject({
|
||||
method,
|
||||
url: endpoint,
|
||||
params: undefined,
|
||||
responseType: "json",
|
||||
timeout: 60000,
|
||||
});
|
||||
expect(requestConfig.headers.Authorization).toBe("Bearer test");
|
||||
});
|
||||
|
||||
it("uses the getMeta route configuration", async () => {
|
||||
const difyClient = new DifyClient("test");
|
||||
mockRequest.mockResolvedValue({ status: 200, data: "ok", headers: {} });
|
||||
|
||||
await difyClient.getMeta("end-user");
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith(expect.objectContaining({
|
||||
method: routes.getMeta.method,
|
||||
url: routes.getMeta.url(),
|
||||
params: { user: "end-user" },
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer test",
|
||||
}),
|
||||
responseType: "json",
|
||||
timeout: 60000,
|
||||
}));
|
||||
});
|
||||
});
|
||||
|
||||
describe("File uploads", () => {
|
||||
const OriginalFormData = globalThis.FormData;
|
||||
|
||||
beforeAll(() => {
|
||||
globalThis.FormData = class FormDataMock {
|
||||
append() {}
|
||||
|
||||
getHeaders() {
|
||||
return {
|
||||
"content-type": "multipart/form-data; boundary=test",
|
||||
};
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
globalThis.FormData = OriginalFormData;
|
||||
});
|
||||
|
||||
it("does not override multipart boundary headers for FormData", async () => {
|
||||
const difyClient = new DifyClient("test");
|
||||
const form = new globalThis.FormData();
|
||||
mockRequest.mockResolvedValue({ status: 200, data: "ok", headers: {} });
|
||||
|
||||
await difyClient.fileUpload(form, "end-user");
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith(expect.objectContaining({
|
||||
method: routes.fileUpload.method,
|
||||
url: routes.fileUpload.url(),
|
||||
params: undefined,
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer test",
|
||||
"content-type": "multipart/form-data; boundary=test",
|
||||
}),
|
||||
responseType: "json",
|
||||
timeout: 60000,
|
||||
data: form,
|
||||
}));
|
||||
});
|
||||
});
|
||||
|
||||
describe("Workflow client", () => {
|
||||
it("uses tasks stop path for workflow stop", async () => {
|
||||
const workflowClient = new WorkflowClient("test");
|
||||
mockRequest.mockResolvedValue({ status: 200, data: "stopped", headers: {} });
|
||||
|
||||
await workflowClient.stop("task-1", "end-user");
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith(expect.objectContaining({
|
||||
method: routes.stopWorkflow.method,
|
||||
url: routes.stopWorkflow.url("task-1"),
|
||||
params: undefined,
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer test",
|
||||
"Content-Type": "application/json",
|
||||
}),
|
||||
responseType: "json",
|
||||
timeout: 60000,
|
||||
data: { user: "end-user" },
|
||||
}));
|
||||
});
|
||||
|
||||
it("maps workflow log filters to service api params", async () => {
|
||||
const workflowClient = new WorkflowClient("test");
|
||||
mockRequest.mockResolvedValue({ status: 200, data: "ok", headers: {} });
|
||||
|
||||
await workflowClient.getLogs({
|
||||
createdAtAfter: "2024-01-01T00:00:00Z",
|
||||
createdAtBefore: "2024-01-02T00:00:00Z",
|
||||
createdByEndUserSessionId: "sess-1",
|
||||
createdByAccount: "acc-1",
|
||||
page: 2,
|
||||
limit: 10,
|
||||
});
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith(expect.objectContaining({
|
||||
method: "GET",
|
||||
url: "/workflows/logs",
|
||||
params: {
|
||||
created_at__after: "2024-01-01T00:00:00Z",
|
||||
created_at__before: "2024-01-02T00:00:00Z",
|
||||
created_by_end_user_session_id: "sess-1",
|
||||
created_by_account: "acc-1",
|
||||
page: 2,
|
||||
limit: 10,
|
||||
},
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer test",
|
||||
}),
|
||||
responseType: "json",
|
||||
timeout: 60000,
|
||||
}));
|
||||
});
|
||||
});
|
||||
|
||||
describe("Chat client", () => {
|
||||
it("places user in query for suggested messages", async () => {
|
||||
const chatClient = new ChatClient("test");
|
||||
mockRequest.mockResolvedValue({ status: 200, data: "ok", headers: {} });
|
||||
|
||||
await chatClient.getSuggested("msg-1", "end-user");
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith(expect.objectContaining({
|
||||
method: routes.getSuggested.method,
|
||||
url: routes.getSuggested.url("msg-1"),
|
||||
params: { user: "end-user" },
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer test",
|
||||
}),
|
||||
responseType: "json",
|
||||
timeout: 60000,
|
||||
}));
|
||||
});
|
||||
|
||||
it("uses last_id when listing conversations", async () => {
|
||||
const chatClient = new ChatClient("test");
|
||||
mockRequest.mockResolvedValue({ status: 200, data: "ok", headers: {} });
|
||||
|
||||
await chatClient.getConversations("end-user", "last-1", 10);
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith(expect.objectContaining({
|
||||
method: routes.getConversations.method,
|
||||
url: routes.getConversations.url(),
|
||||
params: { user: "end-user", last_id: "last-1", limit: 10 },
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer test",
|
||||
}),
|
||||
responseType: "json",
|
||||
timeout: 60000,
|
||||
}));
|
||||
});
|
||||
|
||||
it("lists app feedbacks without user params", async () => {
|
||||
const chatClient = new ChatClient("test");
|
||||
mockRequest.mockResolvedValue({ status: 200, data: "ok", headers: {} });
|
||||
|
||||
await chatClient.getAppFeedbacks(1, 20);
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith(expect.objectContaining({
|
||||
method: "GET",
|
||||
url: "/app/feedbacks",
|
||||
params: { page: 1, limit: 20 },
|
||||
headers: expect.objectContaining({
|
||||
Authorization: "Bearer test",
|
||||
}),
|
||||
responseType: "json",
|
||||
timeout: 60000,
|
||||
}));
|
||||
});
|
||||
});
|
||||
|
|
@@ -0,0 +1,103 @@
|
|||
import { DEFAULT_BASE_URL } from "./types/common";
|
||||
|
||||
export const BASE_URL = DEFAULT_BASE_URL;
|
||||
|
||||
export const routes = {
|
||||
feedback: {
|
||||
method: "POST",
|
||||
url: (messageId: string) => `/messages/${messageId}/feedbacks`,
|
||||
},
|
||||
application: {
|
||||
method: "GET",
|
||||
url: () => "/parameters",
|
||||
},
|
||||
fileUpload: {
|
||||
method: "POST",
|
||||
url: () => "/files/upload",
|
||||
},
|
||||
filePreview: {
|
||||
method: "GET",
|
||||
url: (fileId: string) => `/files/${fileId}/preview`,
|
||||
},
|
||||
textToAudio: {
|
||||
method: "POST",
|
||||
url: () => "/text-to-audio",
|
||||
},
|
||||
audioToText: {
|
||||
method: "POST",
|
||||
url: () => "/audio-to-text",
|
||||
},
|
||||
getMeta: {
|
||||
method: "GET",
|
||||
url: () => "/meta",
|
||||
},
|
||||
getInfo: {
|
||||
method: "GET",
|
||||
url: () => "/info",
|
||||
},
|
||||
getSite: {
|
||||
method: "GET",
|
||||
url: () => "/site",
|
||||
},
|
||||
createCompletionMessage: {
|
||||
method: "POST",
|
||||
url: () => "/completion-messages",
|
||||
},
|
||||
stopCompletionMessage: {
|
||||
method: "POST",
|
||||
url: (taskId: string) => `/completion-messages/${taskId}/stop`,
|
||||
},
|
||||
createChatMessage: {
|
||||
method: "POST",
|
||||
url: () => "/chat-messages",
|
||||
},
|
||||
getSuggested: {
|
||||
method: "GET",
|
||||
url: (messageId: string) => `/messages/${messageId}/suggested`,
|
||||
},
|
||||
stopChatMessage: {
|
||||
method: "POST",
|
||||
url: (taskId: string) => `/chat-messages/${taskId}/stop`,
|
||||
},
|
||||
getConversations: {
|
||||
method: "GET",
|
||||
url: () => "/conversations",
|
||||
},
|
||||
getConversationMessages: {
|
||||
method: "GET",
|
||||
url: () => "/messages",
|
||||
},
|
||||
renameConversation: {
|
||||
method: "POST",
|
||||
url: (conversationId: string) => `/conversations/${conversationId}/name`,
|
||||
},
|
||||
deleteConversation: {
|
||||
method: "DELETE",
|
||||
url: (conversationId: string) => `/conversations/${conversationId}`,
|
||||
},
|
||||
runWorkflow: {
|
||||
method: "POST",
|
||||
url: () => "/workflows/run",
|
||||
},
|
||||
stopWorkflow: {
|
||||
method: "POST",
|
||||
url: (taskId: string) => `/workflows/tasks/${taskId}/stop`,
|
||||
},
|
||||
};
|
||||
|
||||
export { DifyClient } from "./client/base";
|
||||
export { ChatClient } from "./client/chat";
|
||||
export { CompletionClient } from "./client/completion";
|
||||
export { WorkflowClient } from "./client/workflow";
|
||||
export { KnowledgeBaseClient } from "./client/knowledge-base";
|
||||
export { WorkspaceClient } from "./client/workspace";
|
||||
|
||||
export * from "./errors/dify-error";
|
||||
export * from "./types/common";
|
||||
export * from "./types/annotation";
|
||||
export * from "./types/chat";
|
||||
export * from "./types/completion";
|
||||
export * from "./types/knowledge-base";
|
||||
export * from "./types/workflow";
|
||||
export * from "./types/workspace";
|
||||
export { HttpClient } from "./http/client";
|
||||
|
|
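The `routes` table above maps each endpoint to its HTTP method and a URL builder, and the barrel file re-exports the individual clients. A minimal sketch of how a consumer might use both; the package name and the high-level method names are assumptions, only `new ChatClient("...")` and `getAppFeedbacks(page, limit)` are confirmed by the tests above:

```typescript
// Hypothetical package name; adjust to however the SDK is actually published.
import { BASE_URL, ChatClient, routes } from "dify-client";

// Build a raw endpoint URL straight from the routes table.
const feedbackUrl = `${BASE_URL}${routes.feedback.url("message-123")}`;
// -> "https://api.dify.ai/v1/messages/message-123/feedbacks"

// Or go through the exported client (constructor signature taken from the tests).
const chat = new ChatClient("YOUR_API_KEY");
const feedbacks = await chat.getAppFeedbacks(1, 20); // page 1, 20 items per page
```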
@@ -0,0 +1,18 @@
export type AnnotationCreateRequest = {
  question: string;
  answer: string;
};

export type AnnotationReplyActionRequest = {
  score_threshold: number;
  embedding_provider_name: string;
  embedding_model_name: string;
};

export type AnnotationListOptions = {
  page?: number;
  limit?: number;
  keyword?: string;
};

export type AnnotationResponse = Record<string, unknown>;
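For reference, an annotation payload and list options built from the types above; all values are placeholders:

```typescript
import type { AnnotationCreateRequest, AnnotationListOptions } from "./types/annotation";

const annotation: AnnotationCreateRequest = {
  question: "What is Dify?",
  answer: "An open-source LLM application development platform.",
};

const listOptions: AnnotationListOptions = { page: 1, limit: 20, keyword: "refund" };
```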
@@ -0,0 +1,17 @@
import type { StreamEvent } from "./common";

export type ChatMessageRequest = {
  inputs?: Record<string, unknown>;
  query: string;
  user: string;
  response_mode?: "blocking" | "streaming";
  files?: Array<Record<string, unknown>> | null;
  conversation_id?: string;
  auto_generate_name?: boolean;
  workflow_id?: string;
  retriever_from?: "app" | "dataset";
};

export type ChatMessageResponse = Record<string, unknown>;

export type ChatStreamEvent = StreamEvent<Record<string, unknown>>;
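A quick illustration of a `ChatMessageRequest` literal; the field names come from the type above, the values are made up:

```typescript
import type { ChatMessageRequest } from "./types/chat";

const request: ChatMessageRequest = {
  query: "What can this app do?",
  user: "end-user-42",            // required end-user identifier
  response_mode: "streaming",     // or "blocking" for a single JSON response
  inputs: { language: "en" },
  // conversation_id omitted: a new conversation is started
};
```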
@@ -0,0 +1,71 @@
export const DEFAULT_BASE_URL = "https://api.dify.ai/v1";
export const DEFAULT_TIMEOUT_SECONDS = 60;
export const DEFAULT_MAX_RETRIES = 3;
export const DEFAULT_RETRY_DELAY_SECONDS = 1;

export type RequestMethod = "GET" | "POST" | "PATCH" | "PUT" | "DELETE";

export type QueryParamValue =
  | string
  | number
  | boolean
  | Array<string | number | boolean>
  | undefined;

export type QueryParams = Record<string, QueryParamValue>;

export type Headers = Record<string, string>;

export type DifyClientConfig = {
  apiKey: string;
  baseUrl?: string;
  timeout?: number;
  maxRetries?: number;
  retryDelay?: number;
  enableLogging?: boolean;
};

export type DifyResponse<T> = {
  data: T;
  status: number;
  headers: Headers;
  requestId?: string;
};

export type MessageFeedbackRequest = {
  messageId: string;
  user: string;
  rating?: "like" | "dislike" | null;
  content?: string | null;
};

export type TextToAudioRequest = {
  user: string;
  text?: string;
  message_id?: string;
  streaming?: boolean;
  voice?: string;
};

export type StreamEvent<T = unknown> = {
  event?: string;
  data: T | string | null;
  raw: string;
};

export type DifyStream<T = unknown> = AsyncIterable<StreamEvent<T>> & {
  data: NodeJS.ReadableStream;
  status: number;
  headers: Headers;
  requestId?: string;
  toText(): Promise<string>;
  toReadable(): NodeJS.ReadableStream;
};

export type BinaryStream = {
  data: NodeJS.ReadableStream;
  status: number;
  headers: Headers;
  requestId?: string;
  toReadable(): NodeJS.ReadableStream;
};
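The common types above describe the client configuration and the shape of streamed responses. A small sketch of how they fit together; that a streaming call returns a `DifyStream` is an assumption here, only the type definitions are part of this diff:

```typescript
import type { DifyClientConfig, DifyStream, StreamEvent } from "./types/common";

const config: DifyClientConfig = {
  apiKey: "YOUR_API_KEY",
  timeout: 30,          // assumed to be seconds, mirroring DEFAULT_TIMEOUT_SECONDS
  maxRetries: 2,
  enableLogging: true,
};

// DifyStream is an AsyncIterable of StreamEvent, so it can be consumed with for await...of.
async function logEvents(stream: DifyStream<Record<string, unknown>>): Promise<void> {
  for await (const event of stream) {
    const e: StreamEvent<Record<string, unknown>> = event;
    console.log(e.event, e.data);
  }
}
```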
@@ -0,0 +1,13 @@
import type { StreamEvent } from "./common";

export type CompletionRequest = {
  inputs?: Record<string, unknown>;
  response_mode?: "blocking" | "streaming";
  user: string;
  files?: Array<Record<string, unknown>> | null;
  retriever_from?: "app" | "dataset";
};

export type CompletionResponse = Record<string, unknown>;

export type CompletionStreamEvent = StreamEvent<Record<string, unknown>>;
@@ -0,0 +1,184 @@
export type DatasetListOptions = {
  page?: number;
  limit?: number;
  keyword?: string | null;
  tagIds?: string[];
  includeAll?: boolean;
};

export type DatasetCreateRequest = {
  name: string;
  description?: string;
  indexing_technique?: "high_quality" | "economy";
  permission?: string | null;
  external_knowledge_api_id?: string | null;
  provider?: string;
  external_knowledge_id?: string | null;
  retrieval_model?: Record<string, unknown> | null;
  embedding_model?: string | null;
  embedding_model_provider?: string | null;
};

export type DatasetUpdateRequest = {
  name?: string;
  description?: string | null;
  indexing_technique?: "high_quality" | "economy" | null;
  permission?: string | null;
  embedding_model?: string | null;
  embedding_model_provider?: string | null;
  retrieval_model?: Record<string, unknown> | null;
  partial_member_list?: Array<Record<string, string>> | null;
  external_retrieval_model?: Record<string, unknown> | null;
  external_knowledge_id?: string | null;
  external_knowledge_api_id?: string | null;
};

export type DocumentStatusAction = "enable" | "disable" | "archive" | "un_archive";

export type DatasetTagCreateRequest = {
  name: string;
};

export type DatasetTagUpdateRequest = {
  tag_id: string;
  name: string;
};

export type DatasetTagDeleteRequest = {
  tag_id: string;
};

export type DatasetTagBindingRequest = {
  tag_ids: string[];
  target_id: string;
};

export type DatasetTagUnbindingRequest = {
  tag_id: string;
  target_id: string;
};

export type DocumentTextCreateRequest = {
  name: string;
  text: string;
  process_rule?: Record<string, unknown> | null;
  original_document_id?: string | null;
  doc_form?: string;
  doc_language?: string;
  indexing_technique?: string | null;
  retrieval_model?: Record<string, unknown> | null;
  embedding_model?: string | null;
  embedding_model_provider?: string | null;
};

export type DocumentTextUpdateRequest = {
  name?: string | null;
  text?: string | null;
  process_rule?: Record<string, unknown> | null;
  doc_form?: string;
  doc_language?: string;
  retrieval_model?: Record<string, unknown> | null;
};

export type DocumentListOptions = {
  page?: number;
  limit?: number;
  keyword?: string | null;
  status?: string | null;
};

export type DocumentGetOptions = {
  metadata?: "all" | "only" | "without";
};

export type SegmentCreateRequest = {
  segments: Array<Record<string, unknown>>;
};

export type SegmentUpdateRequest = {
  segment: {
    content?: string | null;
    answer?: string | null;
    keywords?: string[] | null;
    regenerate_child_chunks?: boolean;
    enabled?: boolean | null;
    attachment_ids?: string[] | null;
  };
};

export type SegmentListOptions = {
  page?: number;
  limit?: number;
  status?: string[];
  keyword?: string | null;
};

export type ChildChunkCreateRequest = {
  content: string;
};

export type ChildChunkUpdateRequest = {
  content: string;
};

export type ChildChunkListOptions = {
  page?: number;
  limit?: number;
  keyword?: string | null;
};

export type MetadataCreateRequest = {
  type: "string" | "number" | "time";
  name: string;
};

export type MetadataUpdateRequest = {
  name: string;
  value?: string | number | null;
};

export type DocumentMetadataDetail = {
  id: string;
  name: string;
  value?: string | number | null;
};

export type DocumentMetadataOperation = {
  document_id: string;
  metadata_list: DocumentMetadataDetail[];
  partial_update?: boolean;
};

export type MetadataOperationRequest = {
  operation_data: DocumentMetadataOperation[];
};

export type HitTestingRequest = {
  query?: string | null;
  retrieval_model?: Record<string, unknown> | null;
  external_retrieval_model?: Record<string, unknown> | null;
  attachment_ids?: string[] | null;
};

export type DatasourcePluginListOptions = {
  isPublished?: boolean;
};

export type DatasourceNodeRunRequest = {
  inputs: Record<string, unknown>;
  datasource_type: string;
  credential_id?: string | null;
  is_published: boolean;
};

export type PipelineRunRequest = {
  inputs: Record<string, unknown>;
  datasource_type: string;
  datasource_info_list: Array<Record<string, unknown>>;
  start_node_id: string;
  is_published: boolean;
  response_mode: "streaming" | "blocking";
};

export type KnowledgeBaseResponse = Record<string, unknown>;
export type PipelineStreamEvent = Record<string, unknown>;
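To make the knowledge-base request shapes concrete, two example payloads built from the types above; the values and the `retrieval_model` object shape are illustrative assumptions:

```typescript
import type { DatasetCreateRequest, HitTestingRequest } from "./types/knowledge-base";

const createDataset: DatasetCreateRequest = {
  name: "support-articles",
  description: "FAQ and troubleshooting docs",
  indexing_technique: "high_quality",
  embedding_model_provider: "openai",          // example provider name
  embedding_model: "text-embedding-3-small",   // example model name
};

const hitTest: HitTestingRequest = {
  query: "How do I reset my password?",
  retrieval_model: { search_method: "semantic_search", top_k: 3 }, // shape assumed
};
```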
Some files were not shown because too many files have changed in this diff.