Merge remote-tracking branch 'origin/main' into feat/vibe-wf

commit 89b29bd836
@@ -187,7 +187,7 @@ const Template = useMemo(() => {

 **When**: Component directly handles API calls, data transformation, or complex async operations.

-**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks. Project is migrating from SWR to React Query.
+**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.

 ```typescript
 // ❌ Before: API logic in component
@@ -28,17 +28,14 @@ import userEvent from '@testing-library/user-event'

 // i18n (automatically mocked)
-// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
-// No explicit mock needed - it returns translation keys as-is
+// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage
+// No explicit mock needed for most tests
 //
 // Override only if custom translations are required:
-// vi.mock('react-i18next', () => ({
-//   useTranslation: () => ({
-//     t: (key: string) => {
-//       const customTranslations: Record<string, string> = {
-//         'my.custom.key': 'Custom Translation',
-//       }
-//       return customTranslations[key] || key
-//     },
-//   }),
+// import { createReactI18nextMock } from '@/test/i18n-mock'
+// vi.mock('react-i18next', () => createReactI18nextMock({
+//   'my.custom.key': 'Custom Translation',
+//   'button.save': 'Save',
+// }))

 // Router (if component uses useRouter, usePathname, useSearchParams)

@@ -52,23 +52,29 @@ Modules are not mocked automatically. Use `vi.mock` in test files, or add global

 ### 1. i18n (Auto-loaded via Global Mock)

 A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
-**No explicit mock needed** for most tests - it returns translation keys as-is.
-
-For tests requiring custom translations, override the mock:
+The global mock provides:
+
+- `useTranslation` - returns translation keys with namespace prefix
+- `Trans` component - renders i18nKey and components
+- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`)
+- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'`
+
+**Default behavior**: Most tests should use the global mock (no local override needed).
+
+**For custom translations**: Use the helper function from `@/test/i18n-mock`:

 ```typescript
-vi.mock('react-i18next', () => ({
-  useTranslation: () => ({
-    t: (key: string) => {
-      const translations: Record<string, string> = {
-        'my.custom.key': 'Custom translation',
-      }
-      return translations[key] || key
-    },
-  }),
+import { createReactI18nextMock } from '@/test/i18n-mock'
+
+vi.mock('react-i18next', () => createReactI18nextMock({
+  'my.custom.key': 'Custom translation',
+  'button.save': 'Save',
+}))
 ```

 **Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this.

 ### 2. Next.js Router

 ```typescript
6  .github/workflows/api-tests.yml  vendored
@@ -22,12 +22,12 @@ jobs:

     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
         with:
          persist-credentials: false

      - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
        with:
          enable-cache: true
          python-version: ${{ matrix.python-version }}
@@ -57,7 +57,7 @@ jobs:
        run: sh .github/workflows/expose_service_ports.sh

      - name: Set up Sandbox
-        uses: hoverkraft-tech/compose-action@v2.0.2
+        uses: hoverkraft-tech/compose-action@v2
        with:
          compose-file: |
            docker/docker-compose.middleware.yaml
4  .github/workflows/autofix.yml  vendored
@@ -12,7 +12,7 @@ jobs:
    if: github.repository == 'langgenius/dify'
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6

      - name: Check Docker Compose inputs
        id: docker-compose-changes
@@ -27,7 +27,7 @@ jobs:
        with:
          python-version: "3.11"

-      - uses: astral-sh/setup-uv@v6
+      - uses: astral-sh/setup-uv@v7

      - name: Generate Docker Compose
        if: steps.docker-compose-changes.outputs.any_changed == 'true'
2  .github/workflows/build-push.yml  vendored
@@ -90,7 +90,7 @@ jobs:
          touch "/tmp/digests/${sanitized_digest}"

      - name: Upload digest
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v6
        with:
          name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
          path: /tmp/digests/*
8  .github/workflows/db-migration-test.yml  vendored
@@ -13,13 +13,13 @@ jobs:

    steps:
      - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          fetch-depth: 0
          persist-credentials: false

      - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
        with:
          enable-cache: true
          python-version: "3.12"
@@ -63,13 +63,13 @@ jobs:

    steps:
      - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          fetch-depth: 0
          persist-credentials: false

      - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
        with:
          enable-cache: true
          python-version: "3.12"
3  .github/workflows/main-ci.yml  vendored
@@ -27,7 +27,7 @@ jobs:
      vdb-changed: ${{ steps.changes.outputs.vdb }}
      migration-changed: ${{ steps.changes.outputs.migration }}
    steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
      - uses: dorny/paths-filter@v3
        id: changes
        with:
@@ -38,6 +38,7 @@ jobs:
            - '.github/workflows/api-tests.yml'
          web:
            - 'web/**'
+            - '.github/workflows/web-tests.yml'
          vdb:
            - 'api/core/rag/datasource/**'
            - 'docker/**'
30  .github/workflows/style.yml  vendored
@@ -19,13 +19,13 @@ jobs:

    steps:
      - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          persist-credentials: false

      - name: Check changed files
        id: changed-files
-        uses: tj-actions/changed-files@v46
+        uses: tj-actions/changed-files@v47
        with:
          files: |
            api/**
@@ -33,7 +33,7 @@ jobs:

      - name: Setup UV and Python
        if: steps.changed-files.outputs.any_changed == 'true'
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
        with:
          enable-cache: false
          python-version: "3.12"
@@ -68,15 +68,17 @@ jobs:

    steps:
      - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          persist-credentials: false

      - name: Check changed files
        id: changed-files
-        uses: tj-actions/changed-files@v46
+        uses: tj-actions/changed-files@v47
        with:
-          files: web/**
+          files: |
+            web/**
+            .github/workflows/style.yml

      - name: Install pnpm
        uses: pnpm/action-setup@v4
@@ -85,7 +87,7 @@ jobs:
          run_install: false

      - name: Setup NodeJS
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        if: steps.changed-files.outputs.any_changed == 'true'
        with:
          node-version: 22
@@ -108,20 +110,30 @@ jobs:
        working-directory: ./web
        run: pnpm run type-check:tsgo

+      - name: Web dead code check
+        if: steps.changed-files.outputs.any_changed == 'true'
+        working-directory: ./web
+        run: pnpm run knip
+
      - name: Web build check
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm run build

  superlinter:
    name: SuperLinter
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          fetch-depth: 0
          persist-credentials: false

      - name: Check changed files
        id: changed-files
-        uses: tj-actions/changed-files@v46
+        uses: tj-actions/changed-files@v47
        with:
          files: |
            **.sh
4  .github/workflows/tool-test-sdks.yaml  vendored
@@ -25,12 +25,12 @@ jobs:
        working-directory: sdks/nodejs-client

    steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
        with:
          persist-credentials: false

      - name: Use Node.js ${{ matrix.node-version }}
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: ${{ matrix.node-version }}
          cache: ''
@@ -4,7 +4,7 @@ on:
  push:
    branches: [main]
    paths:
-      - 'web/i18n/en-US/*.ts'
+      - 'web/i18n/en-US/*.json'

permissions:
  contents: write
@@ -18,7 +18,7 @@ jobs:
    run:
      working-directory: web
    steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v6
        with:
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}
@@ -28,13 +28,13 @@ jobs:
        run: |
          git fetch origin "${{ github.event.before }}" || true
          git fetch origin "${{ github.sha }}" || true
-          changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
+          changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
          echo "Changed files: $changed_files"
          if [ -n "$changed_files" ]; then
            echo "FILES_CHANGED=true" >> $GITHUB_ENV
            file_args=""
            for file in $changed_files; do
-              filename=$(basename "$file" .ts)
+              filename=$(basename "$file" .json)
              file_args="$file_args --file $filename"
            done
            echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
@@ -51,7 +51,7 @@ jobs:

      - name: Set up Node.js
        if: env.FILES_CHANGED == 'true'
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: 'lts/*'
          cache: pnpm
@@ -65,7 +65,7 @@ jobs:
      - name: Generate i18n translations
        if: env.FILES_CHANGED == 'true'
        working-directory: ./web
-        run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
+        run: pnpm run i18n:gen ${{ env.FILE_ARGS }}

      - name: Create Pull Request
        if: env.FILES_CHANGED == 'true'
6  .github/workflows/vdb-tests.yml  vendored
@@ -19,19 +19,19 @@ jobs:

    steps:
      - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          persist-credentials: false

      - name: Free Disk Space
-        uses: endersonmenezes/free-disk-space@v2
+        uses: endersonmenezes/free-disk-space@v3
        with:
          remove_dotnet: true
          remove_haskell: true
          remove_tool_cache: true

      - name: Setup UV and Python
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v7
        with:
          enable-cache: true
          python-version: ${{ matrix.python-version }}
6  .github/workflows/web-tests.yml  vendored
@@ -18,7 +18,7 @@ jobs:

    steps:
      - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v6
        with:
          persist-credentials: false

@@ -29,7 +29,7 @@ jobs:
        run_install: false

      - name: Setup Node.js
-        uses: actions/setup-node@v4
+        uses: actions/setup-node@v6
        with:
          node-version: 22
          cache: pnpm
@@ -360,7 +360,7 @@ jobs:

      - name: Upload Coverage Artifact
        if: steps.coverage-summary.outputs.has_coverage == 'true'
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v6
        with:
          name: web-coverage-report
          path: web/coverage
@@ -101,6 +101,15 @@ S3_ACCESS_KEY=your-access-key
 S3_SECRET_KEY=your-secret-key
 S3_REGION=your-region

+# Workflow run and Conversation archive storage (S3-compatible)
+ARCHIVE_STORAGE_ENABLED=false
+ARCHIVE_STORAGE_ENDPOINT=
+ARCHIVE_STORAGE_ARCHIVE_BUCKET=
+ARCHIVE_STORAGE_EXPORT_BUCKET=
+ARCHIVE_STORAGE_ACCESS_KEY=
+ARCHIVE_STORAGE_SECRET_KEY=
+ARCHIVE_STORAGE_REGION=auto
+
 # Azure Blob Storage configuration
 AZURE_BLOB_ACCOUNT_NAME=your-account-name
 AZURE_BLOB_ACCOUNT_KEY=your-account-key
@@ -128,6 +137,7 @@ TENCENT_COS_SECRET_KEY=your-secret-key
 TENCENT_COS_SECRET_ID=your-secret-id
 TENCENT_COS_REGION=your-region
 TENCENT_COS_SCHEME=your-scheme
+TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain

 # Huawei OBS Storage Configuration
 HUAWEI_OBS_BUCKET_NAME=your-bucket-name
@@ -1,4 +1,8 @@
-exclude = ["migrations/*"]
+exclude = [
+    "migrations/*",
+    ".git",
+    ".git/**",
+]
 line-length = 120

 [format]
@@ -1,9 +1,11 @@
+from configs.extra.archive_config import ArchiveStorageConfig
 from configs.extra.notion_config import NotionConfig
 from configs.extra.sentry_config import SentryConfig


 class ExtraServiceConfig(
     # place the configs in alphabet order
+    ArchiveStorageConfig,
     NotionConfig,
     SentryConfig,
 ):
43  api/configs/extra/archive_config.py  Normal file
@@ -0,0 +1,43 @@
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+
+class ArchiveStorageConfig(BaseSettings):
+    """
+    Configuration settings for workflow run logs archiving storage.
+    """
+
+    ARCHIVE_STORAGE_ENABLED: bool = Field(
+        description="Enable workflow run logs archiving to S3-compatible storage",
+        default=False,
+    )
+
+    ARCHIVE_STORAGE_ENDPOINT: str | None = Field(
+        description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')",
+        default=None,
+    )
+
+    ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field(
+        description="Name of the bucket to store archived workflow logs",
+        default=None,
+    )
+
+    ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field(
+        description="Name of the bucket to store exported workflow runs",
+        default=None,
+    )
+
+    ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field(
+        description="Access key ID for authenticating with storage",
+        default=None,
+    )
+
+    ARCHIVE_STORAGE_SECRET_KEY: str | None = Field(
+        description="Secret access key for authenticating with storage",
+        default=None,
+    )
+
+    ARCHIVE_STORAGE_REGION: str = Field(
+        description="Region for storage (use 'auto' if the provider supports it)",
+        default="auto",
+    )
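
Note: `pydantic-settings` resolves each `BaseSettings` field from an environment variable of the same name, falling back to the declared default. A minimal sketch (the env values are invented for illustration):

```python
import os

from configs.extra.archive_config import ArchiveStorageConfig

# Each field is read from an env var of the same name when the model is built.
os.environ["ARCHIVE_STORAGE_ENABLED"] = "true"
os.environ["ARCHIVE_STORAGE_ARCHIVE_BUCKET"] = "my-archive-bucket"

config = ArchiveStorageConfig()
assert config.ARCHIVE_STORAGE_ENABLED is True  # "true" parsed as bool
assert config.ARCHIVE_STORAGE_ARCHIVE_BUCKET == "my-archive-bucket"
assert config.ARCHIVE_STORAGE_REGION == "auto"  # default applies when unset
```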
@@ -31,3 +31,8 @@ class TencentCloudCOSStorageConfig(BaseSettings):
         description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
         default=None,
     )
+
+    TENCENT_COS_CUSTOM_DOMAIN: str | None = Field(
+        description="Tencent Cloud COS custom domain setting",
+        default=None,
+    )
@@ -1,62 +1,59 @@
-from flask_restx import Api, Namespace, fields
+from __future__ import annotations

-from libs.helper import AppIconUrlField
+from typing import Any, TypeAlias

-parameters__system_parameters = {
-    "image_file_size_limit": fields.Integer,
-    "video_file_size_limit": fields.Integer,
-    "audio_file_size_limit": fields.Integer,
-    "file_size_limit": fields.Integer,
-    "workflow_file_upload_limit": fields.Integer,
-}
+from pydantic import BaseModel, ConfigDict, computed_field
+
+from core.file import helpers as file_helpers
+from models.model import IconType
+
+JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
+JSONObject: TypeAlias = dict[str, Any]


-def build_system_parameters_model(api_or_ns: Api | Namespace):
-    """Build the system parameters model for the API or Namespace."""
-    return api_or_ns.model("SystemParameters", parameters__system_parameters)
+class SystemParameters(BaseModel):
+    image_file_size_limit: int
+    video_file_size_limit: int
+    audio_file_size_limit: int
+    file_size_limit: int
+    workflow_file_upload_limit: int


-parameters_fields = {
-    "opening_statement": fields.String,
-    "suggested_questions": fields.Raw,
-    "suggested_questions_after_answer": fields.Raw,
-    "speech_to_text": fields.Raw,
-    "text_to_speech": fields.Raw,
-    "retriever_resource": fields.Raw,
-    "annotation_reply": fields.Raw,
-    "more_like_this": fields.Raw,
-    "user_input_form": fields.Raw,
-    "sensitive_word_avoidance": fields.Raw,
-    "file_upload": fields.Raw,
-    "system_parameters": fields.Nested(parameters__system_parameters),
-}
+class Parameters(BaseModel):
+    opening_statement: str | None = None
+    suggested_questions: list[str]
+    suggested_questions_after_answer: JSONObject
+    speech_to_text: JSONObject
+    text_to_speech: JSONObject
+    retriever_resource: JSONObject
+    annotation_reply: JSONObject
+    more_like_this: JSONObject
+    user_input_form: list[JSONObject]
+    sensitive_word_avoidance: JSONObject
+    file_upload: JSONObject
+    system_parameters: SystemParameters


-def build_parameters_model(api_or_ns: Api | Namespace):
-    """Build the parameters model for the API or Namespace."""
-    copied_fields = parameters_fields.copy()
-    copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
-    return api_or_ns.model("Parameters", copied_fields)
-
-
-site_fields = {
-    "title": fields.String,
-    "chat_color_theme": fields.String,
-    "chat_color_theme_inverted": fields.Boolean,
-    "icon_type": fields.String,
-    "icon": fields.String,
-    "icon_background": fields.String,
-    "icon_url": AppIconUrlField,
-    "description": fields.String,
-    "copyright": fields.String,
-    "privacy_policy": fields.String,
-    "custom_disclaimer": fields.String,
-    "default_language": fields.String,
-    "show_workflow_steps": fields.Boolean,
-    "use_icon_as_answer_icon": fields.Boolean,
-}
-
-
-def build_site_model(api_or_ns: Api | Namespace):
-    """Build the site model for the API or Namespace."""
-    return api_or_ns.model("Site", site_fields)
+class Site(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
+    title: str
+    chat_color_theme: str | None = None
+    chat_color_theme_inverted: bool
+    icon_type: str | None = None
+    icon: str | None = None
+    icon_background: str | None = None
+    description: str | None = None
+    copyright: str | None = None
+    privacy_policy: str | None = None
+    custom_disclaimer: str | None = None
+    default_language: str
+    show_workflow_steps: bool
+    use_icon_as_answer_icon: bool
+
+    @computed_field(return_type=str | None)  # type: ignore
+    @property
+    def icon_url(self) -> str | None:
+        if self.icon and self.icon_type == IconType.IMAGE:
+            return file_helpers.get_signed_file_url(self.icon)
+        return None
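
For reference, a small sketch of how the endpoints now use these models in place of `flask_restx` marshalling (the feature-dict values below are invented for illustration):

```python
from controllers.common.fields import Parameters

# Hypothetical dict, shaped like get_parameters_from_feature_dict's output.
raw = {
    "opening_statement": "Hello!",
    "suggested_questions": ["What can you do?"],
    "suggested_questions_after_answer": {"enabled": False},
    "speech_to_text": {"enabled": False},
    "text_to_speech": {"enabled": False},
    "retriever_resource": {"enabled": True},
    "annotation_reply": {"enabled": False},
    "more_like_this": {"enabled": False},
    "user_input_form": [],
    "sensitive_word_avoidance": {"enabled": False},
    "file_upload": {"image": {"enabled": False}},
    "system_parameters": {
        "image_file_size_limit": 10,
        "video_file_size_limit": 100,
        "audio_file_size_limit": 50,
        "file_size_limit": 15,
        "workflow_file_upload_limit": 10,
    },
}

# model_validate parses and type-checks; model_dump(mode="json") produces the
# JSON-safe dict the endpoints return instead of marshal_with output.
payload = Parameters.model_validate(raw).model_dump(mode="json")
assert payload["system_parameters"]["file_size_limit"] == 15
```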
@@ -1,3 +1,4 @@
+import re
 import uuid
 from typing import Literal

@@ -73,6 +74,48 @@ class AppListQuery(BaseModel):
             raise ValueError("Invalid UUID format in tag_ids.") from exc


+# XSS prevention: patterns that could lead to XSS attacks
+# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
+_XSS_PATTERNS = [
+    r"<script[^>]*>.*?</script>",  # Script tags
+    r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)",  # Iframe tags (including self-closing)
+    r"javascript:",  # JavaScript protocol
+    r"<svg[^>]*?\s+onload\s*=[^>]*>",  # SVG with onload handler (attribute-aware, flexible whitespace)
+    r"<.*?on\s*\w+\s*=",  # Event handlers like onclick, onerror, etc.
+    r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)",  # Object tags (opening tag)
+    r"<embed[^>]*>",  # Embed tags (self-closing)
+    r"<link[^>]*>",  # Link tags with javascript
+]
+
+
+def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
+    """
+    Validate that a string value doesn't contain potential XSS payloads.
+
+    Args:
+        value: The string value to validate
+        field_name: Name of the field for error messages
+
+    Returns:
+        The original value if safe
+
+    Raises:
+        ValueError: If the value contains XSS patterns
+    """
+    if value is None:
+        return None
+
+    value_lower = value.lower()
+    for pattern in _XSS_PATTERNS:
+        if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
+            raise ValueError(
+                f"{field_name} contains invalid characters or patterns. "
+                "HTML tags, JavaScript, and other potentially dangerous content are not allowed."
+            )
+
+    return value
+
+
 class CreateAppPayload(BaseModel):
     name: str = Field(..., min_length=1, description="App name")
     description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@@ -81,6 +124,11 @@ class CreateAppPayload(BaseModel):
     icon: str | None = Field(default=None, description="Icon")
     icon_background: str | None = Field(default=None, description="Icon background color")

+    @field_validator("name", "description", mode="before")
+    @classmethod
+    def validate_xss_safe(cls, value: str | None, info) -> str | None:
+        return _validate_xss_safe(value, info.field_name)
+

 class UpdateAppPayload(BaseModel):
     name: str = Field(..., min_length=1, description="App name")
@@ -91,6 +139,11 @@ class UpdateAppPayload(BaseModel):
     use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
     max_active_requests: int | None = Field(default=None, description="Maximum active requests")

+    @field_validator("name", "description", mode="before")
+    @classmethod
+    def validate_xss_safe(cls, value: str | None, info) -> str | None:
+        return _validate_xss_safe(value, info.field_name)
+

 class CopyAppPayload(BaseModel):
     name: str | None = Field(default=None, description="Name for the copied app")
@@ -99,6 +152,11 @@ class CopyAppPayload(BaseModel):
     icon: str | None = Field(default=None, description="Icon")
     icon_background: str | None = Field(default=None, description="Icon background color")

+    @field_validator("name", "description", mode="before")
+    @classmethod
+    def validate_xss_safe(cls, value: str | None, info) -> str | None:
+        return _validate_xss_safe(value, info.field_name)
+

 class AppExportQuery(BaseModel):
     include_secret: bool = Field(default=False, description="Include secrets in export")
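
A quick self-contained check of one of the patterns above (the event-handler rule), showing what it catches and what it lets through:

```python
import re

# The inline event-handler pattern from _XSS_PATTERNS.
event_handler = re.compile(r"<.*?on\s*\w+\s*=", re.DOTALL | re.IGNORECASE)

# Rejected: an image tag smuggling an onerror handler.
assert event_handler.search('<img src=x onerror="alert(1)">'.lower()) is not None
# Accepted: ordinary app names contain no angle brackets.
assert event_handler.search("My perfectly normal app name".lower()) is None
```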
@@ -124,7 +124,7 @@ class OAuthCallback(Resource):
             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")

         try:
-            account = _generate_account(provider, user_info)
+            account, oauth_new_user = _generate_account(provider, user_info)
         except AccountNotFoundError:
             return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
         except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
@@ -159,7 +159,10 @@ class OAuthCallback(Resource):
             ip_address=extract_remote_ip(request),
         )

-        response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
+        base_url = dify_config.CONSOLE_WEB_URL
+        query_char = "&" if "?" in base_url else "?"
+        target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
+        response = redirect(target_url)

         set_access_token_to_cookie(request, response, token_pair.access_token)
         set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
@@ -177,9 +180,10 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
     return account


-def _generate_account(provider: str, user_info: OAuthUserInfo):
+def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
     # Get account by openid or email.
     account = _get_account_by_openid_or_email(provider, user_info)
+    oauth_new_user = False

     if account:
         tenants = TenantService.get_join_tenants(account)
@@ -193,6 +197,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
         tenant_was_created.send(new_tenant)

     if not account:
+        oauth_new_user = True
         if not FeatureService.get_system_features().is_allow_register:
             if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
                 raise AccountRegisterError(
@@ -220,4 +225,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
     # Link account
     AccountService.link_account_integrate(provider, user_info.id, account)

-    return account
+    return account, oauth_new_user
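
The `query_char` logic keeps the redirect URL valid whether or not `CONSOLE_WEB_URL` already carries a query string. A standalone sketch (the helper name is illustrative, not part of the diff):

```python
def append_oauth_flag(base_url: str, oauth_new_user: bool) -> str:
    """Append oauth_new_user to a URL, reusing '&' when a query string exists."""
    query_char = "&" if "?" in base_url else "?"
    return f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"

assert append_oauth_flag("https://cloud.dify.ai", True) == "https://cloud.dify.ai?oauth_new_user=true"
assert append_oauth_flag("https://cloud.dify.ai?lang=en", False) == "https://cloud.dify.ai?lang=en&oauth_new_user=false"
```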
@@ -3,10 +3,12 @@ import uuid

 from flask import request
 from flask_restx import Resource, marshal
 from pydantic import BaseModel, Field
-from sqlalchemy import select
+from sqlalchemy import String, cast, func, or_, select
+from sqlalchemy.dialects.postgresql import JSONB
 from werkzeug.exceptions import Forbidden, NotFound

 import services
+from configs import dify_config
 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.error import ProviderNotInitializeError
@@ -143,7 +145,29 @@ class DatasetDocumentSegmentListApi(Resource):
             query = query.where(DocumentSegment.hit_count >= hit_count_gte)

         if keyword:
-            query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
+            # Search in both content and keywords fields
+            # Use database-specific methods for JSON array search
+            if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
+                # PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
+                keywords_condition = func.array_to_string(
+                    func.array(
+                        select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
+                        .correlate(DocumentSegment)
+                        .scalar_subquery()
+                    ),
+                    ",",
+                ).ilike(f"%{keyword}%")
+            else:
+                # MySQL: Cast JSON to string for pattern matching
+                # MySQL stores Chinese text directly in JSON without Unicode escaping
+                keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
+
+            query = query.where(
+                or_(
+                    DocumentSegment.content.ilike(f"%{keyword}%"),
+                    keywords_condition,
+                )
+            )

         if args.enabled.lower() != "all":
             if args.enabled.lower() == "true":
@@ -1,5 +1,3 @@
-from flask_restx import marshal_with
-
 from controllers.common import fields
 from controllers.console import console_ns
 from controllers.console.app.error import AppUnavailableError
@@ -13,7 +11,6 @@ from services.app_service import AppService
 class AppParameterApi(InstalledAppResource):
     """Resource for app variables."""

-    @marshal_with(fields.parameters_fields)
     def get(self, installed_app: InstalledApp):
         """Retrieve app parameters."""
         app_model = installed_app.app
@@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):

         user_input_form = features_dict.get("user_input_form", [])

-        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        return fields.Parameters.model_validate(parameters).model_dump(mode="json")


 @console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
@@ -20,7 +20,6 @@ from controllers.console.wraps import (
 )
 from core.db.session_factory import session_factory
 from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
-from core.helper.tool_provider_cache import ToolProviderListCache
 from core.mcp.auth.auth_flow import auth, handle_callback
 from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
 from core.mcp.mcp_client import MCPClient
@@ -987,9 +986,6 @@ class ToolProviderMCPApi(Resource):
             # Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
             logger.warning("Failed to fetch MCP tools after creation", exc_info=True)

-        # Final cache invalidation to ensure list views are up to date
-        ToolProviderListCache.invalidate_cache(tenant_id)
-
         return jsonable_encoder(result)

     @console_ns.expect(parser_mcp_put)
@@ -1036,9 +1032,6 @@ class ToolProviderMCPApi(Resource):
                 validation_result=validation_result,
             )

-        # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
-        ToolProviderListCache.invalidate_cache(current_tenant_id)
-
         return {"result": "success"}

     @console_ns.expect(parser_mcp_delete)
@@ -1053,9 +1046,6 @@ class ToolProviderMCPApi(Resource):
             service = MCPToolManageService(session=session)
             service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])

-        # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
-        ToolProviderListCache.invalidate_cache(current_tenant_id)
-
         return {"result": "success"}


@@ -1106,8 +1096,6 @@ class ToolMCPAuthApi(Resource):
                     credentials=provider_entity.credentials,
                     authed=True,
                 )
-                # Invalidate cache after updating credentials
-                ToolProviderListCache.invalidate_cache(tenant_id)
                 return {"result": "success"}
         except MCPAuthError as e:
             try:
@@ -1121,22 +1109,16 @@ class ToolMCPAuthApi(Resource):
             with Session(db.engine) as session, session.begin():
                 service = MCPToolManageService(session=session)
                 response = service.execute_auth_actions(auth_result)
-                # Invalidate cache after auth actions may have updated provider state
-                ToolProviderListCache.invalidate_cache(tenant_id)
                 return response
         except MCPRefreshTokenError as e:
             with Session(db.engine) as session, session.begin():
                 service = MCPToolManageService(session=session)
                 service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
-                # Invalidate cache after clearing credentials
-                ToolProviderListCache.invalidate_cache(tenant_id)
             raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
         except (MCPError, ValueError) as e:
             with Session(db.engine) as session, session.begin():
                 service = MCPToolManageService(session=session)
                 service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
-                # Invalidate cache after clearing credentials
-                ToolProviderListCache.invalidate_cache(tenant_id)
             raise ValueError(f"Failed to connect to MCP server: {e}") from e
@@ -1,7 +1,7 @@
 from typing import Literal

 from flask import request
-from flask_restx import Api, Namespace, Resource, fields
+from flask_restx import Namespace, Resource, fields
 from flask_restx.api import HTTPStatus
 from pydantic import BaseModel, Field

@@ -92,7 +92,7 @@ annotation_list_fields = {
 }


-def build_annotation_list_model(api_or_ns: Api | Namespace):
+def build_annotation_list_model(api_or_ns: Namespace):
     """Build the annotation list model for the API or Namespace."""
     copied_annotation_list_fields = annotation_list_fields.copy()
     copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
@@ -1,6 +1,6 @@
 from flask_restx import Resource

-from controllers.common.fields import build_parameters_model
+from controllers.common.fields import Parameters
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import AppUnavailableError
 from controllers.service_api.wraps import validate_app_token
@@ -23,7 +23,6 @@ class AppParameterApi(Resource):
         }
     )
     @validate_app_token
-    @service_api_ns.marshal_with(build_parameters_model(service_api_ns))
     def get(self, app_model: App):
         """Retrieve app parameters.

@@ -45,7 +44,8 @@ class AppParameterApi(Resource):

         user_input_form = features_dict.get("user_input_form", [])

-        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        return Parameters.model_validate(parameters).model_dump(mode="json")


 @service_api_ns.route("/meta")
@@ -1,7 +1,7 @@
 from flask_restx import Resource
 from werkzeug.exceptions import Forbidden

-from controllers.common.fields import build_site_model
+from controllers.common.fields import Site as SiteResponse
 from controllers.service_api import service_api_ns
 from controllers.service_api.wraps import validate_app_token
 from extensions.ext_database import db
@@ -23,7 +23,6 @@ class AppSiteApi(Resource):
         }
     )
     @validate_app_token
-    @service_api_ns.marshal_with(build_site_model(service_api_ns))
     def get(self, app_model: App):
         """Retrieve app site info.

@@ -38,4 +37,4 @@ class AppSiteApi(Resource):
         if app_model.tenant.status == TenantStatus.ARCHIVE:
             raise Forbidden()

-        return site
+        return SiteResponse.model_validate(site).model_dump(mode="json")
@@ -3,7 +3,7 @@ from typing import Any, Literal

 from dateutil.parser import isoparse
 from flask import request
-from flask_restx import Api, Namespace, Resource, fields
+from flask_restx import Namespace, Resource, fields
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session, sessionmaker
 from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@@ -78,7 +78,7 @@ workflow_run_fields = {
 }


-def build_workflow_run_model(api_or_ns: Api | Namespace):
+def build_workflow_run_model(api_or_ns: Namespace):
     """Build the workflow run model for the API or Namespace."""
     return api_or_ns.model("WorkflowRun", workflow_run_fields)
@@ -1,7 +1,7 @@
 import logging

 from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
 from pydantic import BaseModel, ConfigDict, Field
 from werkzeug.exceptions import Unauthorized

@@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource):
             500: "Internal Server Error",
         }
     )
-    @marshal_with(fields.parameters_fields)
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
@@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource):

         user_input_form = features_dict.get("user_input_form", [])

-        return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
+        return fields.Parameters.model_validate(parameters).model_dump(mode="json")


 @web_ns.route("/meta")
@@ -22,6 +22,7 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
 from core.tools.__base.tool import Tool
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
+from core.workflow.nodes.agent.exc import AgentMaxIterationError
 from models.model import Message

 logger = logging.getLogger(__name__)
@@ -165,6 +166,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
                 self._agent_scratchpad.append(scratchpad)

+                # Check if max iteration is reached and model still wants to call tools
+                if iteration_step == max_iteration_steps and scratchpad.action:
+                    if scratchpad.action.action_name.lower() != "final answer":
+                        raise AgentMaxIterationError(app_config.agent.max_iteration)
+
                 # get llm usage
                 if "usage" in usage_dict:
                     if usage_dict["usage"] is not None:

@@ -25,6 +25,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
+from core.workflow.nodes.agent.exc import AgentMaxIterationError
 from models.model import Message

 logger = logging.getLogger(__name__)
@@ -222,6 +223,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):

             final_answer += response + "\n"

+            # Check if max iteration is reached and model still wants to call tools
+            if iteration_step == max_iteration_steps and tool_calls:
+                raise AgentMaxIterationError(app_config.agent.max_iteration)
+
             # call tools
             tool_responses = []
             for tool_call_id, tool_call_name, tool_call_args in tool_calls:
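
A self-contained sketch of the new guard; `AgentMaxIterationError` is re-declared locally here only so the snippet runs standalone (the real class lives in `core.workflow.nodes.agent.exc`):

```python
class AgentMaxIterationError(Exception):
    def __init__(self, max_iteration: int):
        super().__init__(f"Agent reached the maximum iteration limit ({max_iteration})")

iteration_step, max_iteration_steps = 5, 5
tool_calls = [("call_1", "search", "{}")]  # model still requesting tools

try:
    # On the final allowed step, pending tool calls mean the run cannot finish:
    # fail loudly instead of silently returning a truncated answer.
    if iteration_step == max_iteration_steps and tool_calls:
        raise AgentMaxIterationError(max_iteration_steps)
except AgentMaxIterationError as e:
    print(e)  # Agent reached the maximum iteration limit (5)
```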
@@ -1,9 +1,14 @@
+from collections.abc import Mapping
 from textwrap import dedent
+from typing import Any

 from core.helper.code_executor.template_transformer import TemplateTransformer


 class Jinja2TemplateTransformer(TemplateTransformer):
+    # Use separate placeholder for base64-encoded template to avoid confusion
+    _template_b64_placeholder: str = "{{template_b64}}"
+
     @classmethod
     def transform_response(cls, response: str):
         """
@@ -13,18 +18,35 @@ class Jinja2TemplateTransformer(TemplateTransformer):
         """
         return {"result": cls.extract_result_str_from_response(response)}

+    @classmethod
+    def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
+        """
+        Override base class to use base64 encoding for template code.
+        This prevents issues with special characters (quotes, newlines) in templates
+        breaking the generated Python script. Fixes #26818.
+        """
+        script = cls.get_runner_script()
+        # Encode template as base64 to safely embed any content including quotes
+        code_b64 = cls.serialize_code(code)
+        script = script.replace(cls._template_b64_placeholder, code_b64)
+        inputs_str = cls.serialize_inputs(inputs)
+        script = script.replace(cls._inputs_placeholder, inputs_str)
+        return script
+
     @classmethod
     def get_runner_script(cls) -> str:
         runner_script = dedent(f"""
-            # declare main function
-            def main(**inputs):
-                import jinja2
-                template = jinja2.Template('''{cls._code_placeholder}''')
-                return template.render(**inputs)
+            import jinja2
+            import json
+            from base64 import b64decode
+
+            # declare main function
+            def main(**inputs):
+                # Decode base64-encoded template to handle special characters safely
+                template_code = b64decode('{cls._template_b64_placeholder}').decode('utf-8')
+                template = jinja2.Template(template_code)
+                return template.render(**inputs)

+            # decode and prepare input dict
+            inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
@@ -13,6 +13,15 @@ class TemplateTransformer(ABC):
     _inputs_placeholder: str = "{{inputs}}"
     _result_tag: str = "<<RESULT>>"

+    @classmethod
+    def serialize_code(cls, code: str) -> str:
+        """
+        Serialize template code to base64 to safely embed in generated script.
+        This prevents issues with special characters like quotes breaking the script.
+        """
+        code_bytes = code.encode("utf-8")
+        return b64encode(code_bytes).decode("utf-8")
+
     @classmethod
     def transform_caller(cls, code: str, inputs: Mapping[str, Any]) -> tuple[str, str]:
         """
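
Why base64 fixes the quoting problem: the encoded form contains no quote characters, so it can be spliced into generated source as a plain literal. A minimal round-trip sketch:

```python
from base64 import b64decode, b64encode

# A template containing triple quotes would terminate the old '''...''' literal
# and break the generated runner script.
template = "Hello {{ name }}, she said '''hi'''"

# Base64 output uses only [A-Za-z0-9+/=], so splicing it into source is safe.
encoded = b64encode(template.encode("utf-8")).decode("utf-8")
script = f"template_code = __import__('base64').b64decode('{encoded}').decode('utf-8')"

namespace: dict = {}
exec(script, namespace)  # the spliced literal has no quoting hazards
assert namespace["template_code"] == template
```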
@@ -1,58 +0,0 @@
-import json
-import logging
-from typing import Any, cast
-
-from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
-from extensions.ext_redis import redis_client, redis_fallback
-
-logger = logging.getLogger(__name__)
-
-
-class ToolProviderListCache:
-    """Cache for tool provider lists"""
-
-    CACHE_TTL = 300  # 5 minutes
-
-    @staticmethod
-    def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
-        """Generate cache key for tool providers list"""
-        type_filter = typ or "all"
-        return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
-
-    @staticmethod
-    @redis_fallback(default_return=None)
-    def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
-        """Get cached tool providers"""
-        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
-        cached_data = redis_client.get(cache_key)
-        if cached_data:
-            try:
-                return json.loads(cached_data.decode("utf-8"))
-            except (json.JSONDecodeError, UnicodeDecodeError):
-                logger.warning("Failed to decode cached tool providers data")
-                return None
-        return None
-
-    @staticmethod
-    @redis_fallback()
-    def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
-        """Cache tool providers"""
-        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
-        redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
-
-    @staticmethod
-    @redis_fallback()
-    def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
-        """Invalidate cache for tool providers"""
-        if typ:
-            # Invalidate specific type cache
-            cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
-            redis_client.delete(cache_key)
-        else:
-            # Invalidate all caches for this tenant
-            keys = ["builtin", "model", "api", "workflow", "mcp"]
-            pipeline = redis_client.pipeline()
-            for key in keys:
-                cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key))
-                pipeline.delete(cache_key)
-            pipeline.execute()
@@ -27,26 +27,44 @@ class CleanProcessor:
             pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
             text = re.sub(pattern, "", text)

-            # Remove URL but keep Markdown image URLs
-            # First, temporarily replace Markdown image URLs with a placeholder
-            markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
-            placeholders: list[str] = []
+            # Remove URL but keep Markdown image URLs and link URLs
+            # Replace the ENTIRE markdown link/image with a single placeholder to protect
+            # the link text (which might also be a URL) from being removed
+            markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)"
+            markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)"
+            placeholders: list[tuple[str, str, str]] = []  # (type, text, url)

-            def replace_with_placeholder(match, placeholders=placeholders):
+            def replace_markdown_with_placeholder(match, placeholders=placeholders):
+                link_type = "link"
+                link_text = match.group(1)
+                url = match.group(2)
+                placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
+                placeholders.append((link_type, link_text, url))
+                return placeholder
+
+            def replace_image_with_placeholder(match, placeholders=placeholders):
+                link_type = "image"
                 url = match.group(1)
-                placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
-                placeholders.append(url)
-                return f"![image]({placeholder})"
+                placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
+                placeholders.append((link_type, "image", url))
+                return placeholder

-            text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
+            # Protect markdown links first
+            text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text)
+            # Then protect markdown images
+            text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text)

             # Now remove all remaining URLs
-            url_pattern = r"https?://[^\s)]+"
+            url_pattern = r"https?://\S+"
             text = re.sub(url_pattern, "", text)

-            # Finally, restore the Markdown image URLs
-            for i, url in enumerate(placeholders):
-                text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
+            # Restore the Markdown links and images
+            for i, (link_type, text_or_alt, url) in enumerate(placeholders):
+                placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__"
+                if link_type == "link":
+                    text = text.replace(placeholder, f"[{text_or_alt}]({url})")
+                else:  # image
+                    text = text.replace(placeholder, f"![{text_or_alt}]({url})")
             return text

     def filter_string(self, text):
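
A simplified round-trip of the protect/strip/restore scheme above (regexes copied from the diff; the input string is invented):

```python
import re

text = "See [Dify docs](https://docs.dify.ai) and https://example.com/tracking"

protected: list[str] = []

def protect(match: re.Match) -> str:
    placeholder = f"__MD_{len(protected)}__"
    protected.append(match.group(0))  # keep the whole markdown link intact
    return placeholder

# Protect markdown links, strip bare URLs, then restore the links.
text = re.sub(r"\[([^\]]*)\]\((https?://[^)]+)\)", protect, text)
text = re.sub(r"https?://\S+", "", text)
for i, original in enumerate(protected):
    text = text.replace(f"__MD_{i}__", original)

assert text == "See [Dify docs](https://docs.dify.ai) and "
```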
@@ -1,4 +1,5 @@
 import concurrent.futures
+import logging
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any

@@ -36,6 +37,8 @@ default_retrieval_model = {
     "score_threshold_enabled": False,
 }

+logger = logging.getLogger(__name__)
+

 class RetrievalService:
     # Cache precompiled regular expressions to avoid repeated compilation
@@ -106,7 +109,12 @@ class RetrievalService:
                     )
                 )

-            concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
+            if futures:
+                for future in concurrent.futures.as_completed(futures, timeout=3600):
+                    if exceptions:
+                        for f in futures:
+                            f.cancel()
+                        break

         if exceptions:
             raise ValueError(";\n".join(exceptions))
@@ -210,6 +218,7 @@ class RetrievalService:
                 )
                 all_documents.extend(documents)
         except Exception as e:
+            logger.error(e, exc_info=True)
             exceptions.append(str(e))

     @classmethod
@@ -303,6 +312,7 @@ class RetrievalService:
             else:
                 all_documents.extend(documents)
         except Exception as e:
+            logger.error(e, exc_info=True)
             exceptions.append(str(e))

     @classmethod
@@ -351,6 +361,7 @@ class RetrievalService:
             else:
                 all_documents.extend(documents)
         except Exception as e:
+            logger.error(e, exc_info=True)
             exceptions.append(str(e))

     @staticmethod
@@ -663,7 +674,14 @@ class RetrievalService:
                         document_ids_filter=document_ids_filter,
                     )
                 )
-            concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
+            # Use as_completed for early error propagation - cancel remaining futures on first error
+            if futures:
+                for future in concurrent.futures.as_completed(futures, timeout=300):
+                    if future.exception():
+                        # Cancel remaining futures to avoid unnecessary waiting
+                        for f in futures:
+                            f.cancel()
+                        break

         if exceptions:
             raise ValueError(";\n".join(exceptions))
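
For reference, a generic sketch of the fail-fast pattern these hunks switch to (names are illustrative; note `Future.cancel()` only prevents queued futures from starting, it cannot stop one already running):

```python
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor

def run_fail_fast(tasks, timeout: float = 60.0):
    """Run callables in a pool; cancel whatever hasn't started once one fails."""
    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(t) for t in tasks]
        # as_completed yields futures as they finish, so the first failure is
        # seen immediately instead of after all futures complete.
        for future in concurrent.futures.as_completed(futures, timeout=timeout):
            if future.exception():
                for f in futures:
                    f.cancel()
                break
    return [f.result() for f in futures if f.done() and not f.cancelled() and not f.exception()]
```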
@@ -112,7 +112,7 @@ class ExtractProcessor:
                 if file_extension in {".xlsx", ".xls"}:
                     extractor = ExcelExtractor(file_path)
                 elif file_extension == ".pdf":
-                    extractor = PdfExtractor(file_path)
+                    extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension in {".md", ".markdown", ".mdx"}:
                     extractor = (
                         UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@@ -148,7 +148,7 @@ class ExtractProcessor:
                 if file_extension in {".xlsx", ".xls"}:
                     extractor = ExcelExtractor(file_path)
                 elif file_extension == ".pdf":
-                    extractor = PdfExtractor(file_path)
+                    extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension in {".md", ".markdown", ".mdx"}:
                     extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
                 elif file_extension in {".htm", ".html"}:
@@ -1,25 +1,57 @@
 """Abstract interface for document loader implementations."""

+import contextlib
+import io
+import logging
+import uuid
 from collections.abc import Iterator

+import pypdfium2
+import pypdfium2.raw as pdfium_c
+
+from configs import dify_config
 from core.rag.extractor.blob.blob import Blob
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from libs.datetime_utils import naive_utc_now
+from models.enums import CreatorUserRole
+from models.model import UploadFile
+
+logger = logging.getLogger(__name__)


 class PdfExtractor(BaseExtractor):
-    """Load pdf files.
+    """
+    PdfExtractor is used to extract text and images from PDF files.

     Args:
-        file_path: Path to the file to load.
+        file_path: Path to the PDF file.
+        tenant_id: Workspace ID.
+        user_id: ID of the user performing the extraction.
+        file_cache_key: Optional cache key for the extracted text.
     """

-    def __init__(self, file_path: str, file_cache_key: str | None = None):
-        """Initialize with file path."""
+    # Magic bytes for image format detection: (magic_bytes, extension, mime_type)
+    IMAGE_FORMATS = [
+        (b"\xff\xd8\xff", "jpg", "image/jpeg"),
+        (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
+        (b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
+        (b"GIF8", "gif", "image/gif"),
+        (b"BM", "bmp", "image/bmp"),
+        (b"II*\x00", "tiff", "image/tiff"),
+        (b"MM\x00*", "tiff", "image/tiff"),
+        (b"II+\x00", "tiff", "image/tiff"),
+        (b"MM\x00+", "tiff", "image/tiff"),
+    ]
+    MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
+
+    def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
+        """Initialize PdfExtractor."""
         self._file_path = file_path
+        self._tenant_id = tenant_id
+        self._user_id = user_id
         self._file_cache_key = file_cache_key

     def extract(self) -> list[Document]:
@@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor):

     def parse(self, blob: Blob) -> Iterator[Document]:
         """Lazily parse the blob."""
-        import pypdfium2  # type: ignore

         with blob.as_bytes_io() as file_path:
             pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
@@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor):
                     text_page = page.get_textpage()
                     content = text_page.get_text_range()
                     text_page.close()
+
+                    image_content = self._extract_images(page)
+                    if image_content:
+                        content += "\n" + image_content
+
                     page.close()
                     metadata = {"source": blob.source, "page": page_number}
                     yield Document(page_content=content, metadata=metadata)
             finally:
                 pdf_reader.close()

+    def _extract_images(self, page) -> str:
+        """
+        Extract images from a PDF page, save them to storage and database,
+        and return markdown image links.
+
+        Args:
+            page: pypdfium2 page object.
+
+        Returns:
+            Markdown string containing links to the extracted images.
+        """
+        image_content = []
+        upload_files = []
+        base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
+
+        try:
+            image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
+            for obj in image_objects:
+                try:
+                    # Extract image bytes
+                    img_byte_arr = io.BytesIO()
+                    # Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly
+                    # Fallback to png for other formats
+                    obj.extract(img_byte_arr, fb_format="png")
+                    img_bytes = img_byte_arr.getvalue()
+
+                    if not img_bytes:
+                        continue
+
+                    header = img_bytes[: self.MAX_MAGIC_LEN]
+                    image_ext = None
+                    mime_type = None
+                    for magic, ext, mime in self.IMAGE_FORMATS:
+                        if header.startswith(magic):
+                            image_ext = ext
+                            mime_type = mime
+                            break
+
+                    if not image_ext or not mime_type:
+                        continue
+
+                    file_uuid = str(uuid.uuid4())
+                    file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." + image_ext
+
+                    storage.save(file_key, img_bytes)
+
+                    # save file to db
+                    upload_file = UploadFile(
+                        tenant_id=self._tenant_id,
+                        storage_type=dify_config.STORAGE_TYPE,
+                        key=file_key,
+                        name=file_key,
+                        size=len(img_bytes),
+                        extension=image_ext,
+                        mime_type=mime_type,
+                        created_by=self._user_id,
+                        created_by_role=CreatorUserRole.ACCOUNT,
+                        created_at=naive_utc_now(),
+                        used=True,
+                        used_by=self._user_id,
+                        used_at=naive_utc_now(),
+                    )
+                    upload_files.append(upload_file)
+                    image_content.append(f"![image]({base_url}/files/{upload_file.id}/file-preview)")
+                except Exception as e:
+                    logger.warning("Failed to extract image from PDF: %s", e)
+                    continue
+        except Exception as e:
+            logger.warning("Failed to get objects from PDF page: %s", e)
+        if upload_files:
+            db.session.add_all(upload_files)
+            db.session.commit()
+        return "\n".join(image_content)
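
A tiny standalone check of the magic-byte table: detecting a PNG header from its first bytes (the sample header is synthetic):

```python
# First bytes of a PNG file followed by arbitrary payload.
png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 8

IMAGE_FORMATS = [
    (b"\xff\xd8\xff", "jpg", "image/jpeg"),
    (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
]

# Take the first entry whose magic bytes prefix the data.
detected = next(
    ((ext, mime) for magic, ext, mime in IMAGE_FORMATS if png_header.startswith(magic)),
    None,
)
assert detected == ("png", "image/png")
```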
@ -516,6 +516,9 @@ class DatasetRetrieval:
            ].embedding_model_provider
            weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
        with measure_time() as timer:
            cancel_event = threading.Event()
            thread_exceptions: list[Exception] = []

            if query:
                query_thread = threading.Thread(
                    target=self._multiple_retrieve_thread,
@ -534,6 +537,8 @@ class DatasetRetrieval:
                        "score_threshold": score_threshold,
                        "query": query,
                        "attachment_id": None,
                        "cancel_event": cancel_event,
                        "thread_exceptions": thread_exceptions,
                    },
                )
                all_threads.append(query_thread)
@ -557,12 +562,25 @@ class DatasetRetrieval:
                        "score_threshold": score_threshold,
                        "query": None,
                        "attachment_id": attachment_id,
                        "cancel_event": cancel_event,
                        "thread_exceptions": thread_exceptions,
                    },
                )
                all_threads.append(attachment_thread)
                attachment_thread.start()
            for thread in all_threads:
                thread.join()

            # Poll threads with short timeout to detect errors quickly (fail-fast)
            while any(t.is_alive() for t in all_threads):
                for thread in all_threads:
                    thread.join(timeout=0.1)
                    if thread_exceptions:
                        cancel_event.set()
                        break
                if thread_exceptions:
                    break

        if thread_exceptions:
            raise thread_exceptions[0]
        self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)

        if all_documents:
@ -1404,40 +1422,53 @@ class DatasetRetrieval:
        score_threshold: float,
        query: str | None,
        attachment_id: str | None,
        cancel_event: threading.Event | None = None,
        thread_exceptions: list[Exception] | None = None,
    ):
        with flask_app.app_context():
            threads = []
            all_documents_item: list[Document] = []
            index_type = None
            for dataset in available_datasets:
                index_type = dataset.indexing_technique
                document_ids_filter = None
                if dataset.provider != "external":
                    if metadata_condition and not metadata_filter_document_ids:
                        continue
                    if metadata_filter_document_ids:
                        document_ids = metadata_filter_document_ids.get(dataset.id, [])
                        if document_ids:
                            document_ids_filter = document_ids
                        else:
        try:
            with flask_app.app_context():
                threads = []
                all_documents_item: list[Document] = []
                index_type = None
                for dataset in available_datasets:
                    # Check for cancellation signal
                    if cancel_event and cancel_event.is_set():
                        break
                    index_type = dataset.indexing_technique
                    document_ids_filter = None
                    if dataset.provider != "external":
                        if metadata_condition and not metadata_filter_document_ids:
                            continue
                retrieval_thread = threading.Thread(
                    target=self._retriever,
                    kwargs={
                        "flask_app": flask_app,
                        "dataset_id": dataset.id,
                        "query": query,
                        "top_k": top_k,
                        "all_documents": all_documents_item,
                        "document_ids_filter": document_ids_filter,
                        "metadata_condition": metadata_condition,
                        "attachment_ids": [attachment_id] if attachment_id else None,
                    },
                )
                threads.append(retrieval_thread)
                retrieval_thread.start()
            for thread in threads:
                thread.join()
                        if metadata_filter_document_ids:
                            document_ids = metadata_filter_document_ids.get(dataset.id, [])
                            if document_ids:
                                document_ids_filter = document_ids
                            else:
                                continue
                    retrieval_thread = threading.Thread(
                        target=self._retriever,
                        kwargs={
                            "flask_app": flask_app,
                            "dataset_id": dataset.id,
                            "query": query,
                            "top_k": top_k,
                            "all_documents": all_documents_item,
                            "document_ids_filter": document_ids_filter,
                            "metadata_condition": metadata_condition,
                            "attachment_ids": [attachment_id] if attachment_id else None,
                        },
                    )
                    threads.append(retrieval_thread)
                    retrieval_thread.start()

                # Poll threads with short timeout to respond quickly to cancellation
                while any(t.is_alive() for t in threads):
                    for thread in threads:
                        thread.join(timeout=0.1)
                        if cancel_event and cancel_event.is_set():
                            break
                    if cancel_event and cancel_event.is_set():
                        break

                if reranking_enable:
                    # do rerank for searched documents
@ -1470,3 +1501,8 @@ class DatasetRetrieval:
                all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
                if all_documents_item:
                    all_documents.extend(all_documents_item)
        except Exception as e:
            if cancel_event:
                cancel_event.set()
            if thread_exceptions is not None:
                thread_exceptions.append(e)

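The fail-fast pattern above (a shared exception list, short `join` timeouts, and a cancel event) can be read in isolation. A self-contained sketch with hypothetical worker functions:

```python
import threading
from collections.abc import Callable


def run_fail_fast(workers: list[Callable[[threading.Event], None]]) -> None:
    """Run workers in threads; cancel the rest and re-raise on the first error."""
    cancel_event = threading.Event()
    errors: list[Exception] = []

    def _runner(fn: Callable[[threading.Event], None]) -> None:
        try:
            fn(cancel_event)  # workers should poll cancel_event periodically
        except Exception as e:
            errors.append(e)
            cancel_event.set()

    threads = [threading.Thread(target=_runner, args=(fn,)) for fn in workers]
    for t in threads:
        t.start()
    # Poll with a short timeout instead of a blocking join, so an error in one
    # worker stops the wait within ~100ms rather than after all threads finish.
    while any(t.is_alive() for t in threads):
        for t in threads:
            t.join(timeout=0.1)
            if errors:
                break
        if errors:
            break
    if errors:
        raise errors[0]
```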
@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
    @staticmethod
    def auto_parse_to_tool_bundle(
        content: str, extra_info: dict | None = None, warning: dict | None = None
    ) -> tuple[list[ApiToolBundle], str]:
    ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
        """
        auto parse to tool bundle

@ -4,6 +4,7 @@ import re
def remove_leading_symbols(text: str) -> str:
    """
    Remove leading punctuation or symbols from the given text.
    Preserves markdown links like [text](url) at the start.

    Args:
        text (str): The input text to process.
@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str:
    Returns:
        str: The text with leading punctuation or symbols removed.
    """
    # Check if text starts with a markdown link - preserve it
    markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)"
    if re.match(markdown_link_pattern, text):
        return text

    # Match Unicode ranges for punctuation and symbols
    # FIXME this pattern is confused quick fix for #11868 maybe refactor it later
    pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'

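Expected behavior after this change (outputs inferred from the patterns above, not from test fixtures):

```python
# Leading markdown links survive; other leading punctuation is still stripped.
remove_leading_symbols("[Dify](https://dify.ai) quickstart")  # -> "[Dify](https://dify.ai) quickstart"
remove_leading_symbols("[broken markdown")                    # -> "broken markdown" ("[" is in the strip class)
remove_leading_symbols("...hello")                            # -> "hello"
```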
@ -54,7 +54,6 @@ class WorkflowToolProviderController(ToolProviderController):
            raise ValueError("app not found")

        user = session.get(Account, db_provider.user_id) if db_provider.user_id else None

        controller = WorkflowToolProviderController(
            entity=ToolProviderEntity(
                identity=ToolProviderIdentity(
@ -67,7 +66,7 @@ class WorkflowToolProviderController(ToolProviderController):
                credentials_schema=[],
                plugin_id=None,
            ),
            provider_id="",
            provider_id=db_provider.id,
        )

        controller.tools = [

@ -60,6 +60,7 @@ class SkipPropagator:
        if edge_states["has_taken"]:
            # Enqueue node
            self._state_manager.enqueue_node(downstream_node_id)
            self._state_manager.start_execution(downstream_node_id)
            return

        # All edges are skipped, propagate skip to this node

@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
        self.expected_type = expected_type
        self.actual_type = actual_type
        super().__init__(message)


class AgentMaxIterationError(AgentNodeError):
    """Exception raised when the agent exceeds the maximum iteration limit."""

    def __init__(self, max_iteration: int):
        self.max_iteration = max_iteration
        super().__init__(
            f"Agent exceeded the maximum iteration limit of {max_iteration}. "
            f"The agent was unable to complete the task within the allowed number of iterations."
        )

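A hypothetical call site for the new exception, showing the structured attribute it carries:

```python
# Sketch only; assumes AgentMaxIterationError from the hunk above is importable.
try:
    raise AgentMaxIterationError(max_iteration=10)
except AgentMaxIterationError as e:
    assert e.max_iteration == 10  # the limit stays machine-readable on the exception
    print(e)                      # human-readable message built in __init__
```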
@ -12,9 +12,8 @@ from dify_app import DifyApp

def _get_celery_ssl_options() -> dict[str, Any] | None:
    """Get SSL configuration for Celery broker/backend connections."""
    # Use REDIS_USE_SSL for consistency with the main Redis client
    # Only apply SSL if we're using Redis as broker/backend
    if not dify_config.REDIS_USE_SSL:
    if not dify_config.BROKER_USE_SSL:
        return None

    # Check if Celery is actually using Redis

@ -13,12 +13,20 @@ class TencentCosStorage(BaseStorage):
        super().__init__()

        self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
        config = CosConfig(
            Region=dify_config.TENCENT_COS_REGION,
            SecretId=dify_config.TENCENT_COS_SECRET_ID,
            SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
            Scheme=dify_config.TENCENT_COS_SCHEME,
        )
        if dify_config.TENCENT_COS_CUSTOM_DOMAIN:
            config = CosConfig(
                Domain=dify_config.TENCENT_COS_CUSTOM_DOMAIN,
                SecretId=dify_config.TENCENT_COS_SECRET_ID,
                SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
                Scheme=dify_config.TENCENT_COS_SCHEME,
            )
        else:
            config = CosConfig(
                Region=dify_config.TENCENT_COS_REGION,
                SecretId=dify_config.TENCENT_COS_SECRET_ID,
                SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
                Scheme=dify_config.TENCENT_COS_SCHEME,
            )
        self.client = CosS3Client(config)

    def save(self, filename, data):

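The two branches differ only in the addressing argument (`Domain` vs `Region`). An equivalent branch-free construction, as a sketch of the same behavior:

```python
# Sketch: build the shared kwargs once, vary only the addressing field.
common = dict(
    SecretId=dify_config.TENCENT_COS_SECRET_ID,
    SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
    Scheme=dify_config.TENCENT_COS_SCHEME,
)
if dify_config.TENCENT_COS_CUSTOM_DOMAIN:
    # A custom domain (e.g. CDN or private endpoint) takes precedence over region addressing.
    config = CosConfig(Domain=dify_config.TENCENT_COS_CUSTOM_DOMAIN, **common)
else:
    config = CosConfig(Region=dify_config.TENCENT_COS_REGION, **common)
```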
@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from libs.helper import TimestampField

@ -12,7 +12,7 @@ annotation_fields = {
}


def build_annotation_model(api_or_ns: Api | Namespace):
def build_annotation_model(api_or_ns: Namespace):
    """Build the annotation model for the API or Namespace."""
    return api_or_ns.model("Annotation", annotation_fields)


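The same signature narrowing repeats across all the fields modules below: the builders now accept only a `Namespace`, keeping model registration namespace-scoped. Usage sketch:

```python
from flask_restx import Namespace

ns = Namespace("annotations", description="Annotation operations")
annotation_model = build_annotation_model(ns)  # models are always registered on a namespace now
```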
@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@ -46,7 +46,7 @@ message_file_fields = {
}


def build_message_file_model(api_or_ns: Api | Namespace):
def build_message_file_model(api_or_ns: Namespace):
    """Build the message file fields for the API or Namespace."""
    return api_or_ns.model("MessageFile", message_file_fields)

@ -217,7 +217,7 @@ conversation_infinite_scroll_pagination_fields = {
}


def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
    """Build the conversation infinite scroll pagination model for the API or Namespace."""
    simple_conversation_model = build_simple_conversation_model(api_or_ns)

@ -226,11 +226,11 @@ def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespa
    return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)


def build_conversation_delete_model(api_or_ns: Api | Namespace):
def build_conversation_delete_model(api_or_ns: Namespace):
    """Build the conversation delete model for the API or Namespace."""
    return api_or_ns.model("ConversationDelete", conversation_delete_fields)


def build_simple_conversation_model(api_or_ns: Api | Namespace):
def build_simple_conversation_model(api_or_ns: Namespace):
    """Build the simple conversation model for the API or Namespace."""
    return api_or_ns.model("SimpleConversation", simple_conversation_fields)

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from libs.helper import TimestampField

@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
}


def build_conversation_variable_model(api_or_ns: Api | Namespace):
def build_conversation_variable_model(api_or_ns: Namespace):
    """Build the conversation variable model for the API or Namespace."""
    return api_or_ns.model("ConversationVariable", conversation_variable_fields)


def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
    """Build the conversation variable infinite scroll pagination model for the API or Namespace."""
    # Build the nested variable model first
    conversation_variable_model = build_conversation_variable_model(api_or_ns)

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

simple_end_user_fields = {
    "id": fields.String,
@ -8,5 +8,5 @@ simple_end_user_fields = {
}


def build_simple_end_user_model(api_or_ns: Api | Namespace):
def build_simple_end_user_model(api_or_ns: Namespace):
    return api_or_ns.model("SimpleEndUser", simple_end_user_fields)

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from libs.helper import TimestampField

@ -14,7 +14,7 @@ upload_config_fields = {
}


def build_upload_config_model(api_or_ns: Api | Namespace):
def build_upload_config_model(api_or_ns: Namespace):
    """Build the upload config model for the API or Namespace.

    Args:
@ -39,7 +39,7 @@ file_fields = {
}


def build_file_model(api_or_ns: Api | Namespace):
def build_file_model(api_or_ns: Namespace):
    """Build the file model for the API or Namespace.

    Args:
@ -57,7 +57,7 @@ remote_file_info_fields = {
}


def build_remote_file_info_model(api_or_ns: Api | Namespace):
def build_remote_file_info_model(api_or_ns: Namespace):
    """Build the remote file info model for the API or Namespace.

    Args:
@ -81,7 +81,7 @@ file_fields_with_signed_url = {
}


def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
def build_file_with_signed_url_model(api_or_ns: Namespace):
    """Build the file with signed URL model for the API or Namespace.

    Args:

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from libs.helper import AvatarUrlField, TimestampField

@ -9,7 +9,7 @@ simple_account_fields = {
}


def build_simple_account_model(api_or_ns: Api | Namespace):
def build_simple_account_model(api_or_ns: Namespace):
    return api_or_ns.model("SimpleAccount", simple_account_fields)


@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
@ -10,7 +10,7 @@ feedback_fields = {
}


def build_feedback_model(api_or_ns: Api | Namespace):
def build_feedback_model(api_or_ns: Namespace):
    """Build the feedback model for the API or Namespace."""
    return api_or_ns.model("Feedback", feedback_fields)

@ -30,7 +30,7 @@ agent_thought_fields = {
}


def build_agent_thought_model(api_or_ns: Api | Namespace):
def build_agent_thought_model(api_or_ns: Namespace):
    """Build the agent thought model for the API or Namespace."""
    return api_or_ns.model("AgentThought", agent_thought_fields)


@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

dataset_tag_fields = {
    "id": fields.String,
@ -8,5 +8,5 @@ dataset_tag_fields = {
}


def build_dataset_tag_fields(api_or_ns: Api | Namespace):
def build_dataset_tag_fields(api_or_ns: Namespace):
    return api_or_ns.model("DataSetTag", dataset_tag_fields)

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
from fields.member_fields import build_simple_account_model, simple_account_fields
@ -17,7 +17,7 @@ workflow_app_log_partial_fields = {
}


def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
    """Build the workflow app log partial model for the API or Namespace."""
    workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
    simple_account_model = build_simple_account_model(api_or_ns)
@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = {
}


def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
    """Build the workflow app log pagination model for the API or Namespace."""
    # Build the nested partial model first
    workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields

from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@ -19,7 +19,7 @@ workflow_run_for_log_fields = {
}


def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
def build_workflow_run_for_log_model(api_or_ns: Namespace):
    return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)

347
api/libs/archive_storage.py
Normal file
@ -0,0 +1,347 @@
"""
Archive Storage Client for S3-compatible storage.

This module provides a dedicated storage client for archiving or exporting logs
to S3-compatible object storage.
"""

import base64
import datetime
import gzip
import hashlib
import logging
from collections.abc import Generator
from typing import Any, cast

import boto3
import orjson
from botocore.client import Config
from botocore.exceptions import ClientError

from configs import dify_config

logger = logging.getLogger(__name__)


class ArchiveStorageError(Exception):
    """Base exception for archive storage operations."""

    pass


class ArchiveStorageNotConfiguredError(ArchiveStorageError):
    """Raised when archive storage is not properly configured."""

    pass


class ArchiveStorage:
    """
    S3-compatible storage client for archiving or exporting.

    This client provides methods for storing and retrieving archived data in JSONL+gzip format.
    """

    def __init__(self, bucket: str):
        if not dify_config.ARCHIVE_STORAGE_ENABLED:
            raise ArchiveStorageNotConfiguredError("Archive storage is not enabled")

        if not bucket:
            raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured")
        if not all(
            [
                dify_config.ARCHIVE_STORAGE_ENDPOINT,
                bucket,
                dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
                dify_config.ARCHIVE_STORAGE_SECRET_KEY,
            ]
        ):
            raise ArchiveStorageNotConfiguredError(
                "Archive storage configuration is incomplete. "
                "Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, "
                "ARCHIVE_STORAGE_SECRET_KEY, and a bucket name"
            )

        self.bucket = bucket
        self.client = boto3.client(
            "s3",
            endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT,
            aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
            aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY,
            region_name=dify_config.ARCHIVE_STORAGE_REGION,
            config=Config(s3={"addressing_style": "path"}),
        )

        # Verify bucket accessibility
        try:
            self.client.head_bucket(Bucket=self.bucket)
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code")
            if error_code == "404":
                raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist")
            elif error_code == "403":
                raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'")
            else:
                raise ArchiveStorageError(f"Failed to access archive bucket: {e}")

    def put_object(self, key: str, data: bytes) -> str:
        """
        Upload an object to the archive storage.

        Args:
            key: Object key (path) within the bucket
            data: Binary data to upload

        Returns:
            MD5 checksum of the uploaded data

        Raises:
            ArchiveStorageError: If upload fails
        """
        checksum = hashlib.md5(data).hexdigest()
        try:
            self.client.put_object(
                Bucket=self.bucket,
                Key=key,
                Body=data,
                ContentMD5=self._content_md5(data),
            )
            logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum)
            return checksum
        except ClientError as e:
            raise ArchiveStorageError(f"Failed to upload object '{key}': {e}")

    def get_object(self, key: str) -> bytes:
        """
        Download an object from the archive storage.

        Args:
            key: Object key (path) within the bucket

        Returns:
            Binary data of the object

        Raises:
            ArchiveStorageError: If download fails
            FileNotFoundError: If object does not exist
        """
        try:
            response = self.client.get_object(Bucket=self.bucket, Key=key)
            return response["Body"].read()
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code")
            if error_code == "NoSuchKey":
                raise FileNotFoundError(f"Archive object not found: {key}")
            raise ArchiveStorageError(f"Failed to download object '{key}': {e}")

    def get_object_stream(self, key: str) -> Generator[bytes, None, None]:
        """
        Stream an object from the archive storage.

        Args:
            key: Object key (path) within the bucket

        Yields:
            Chunks of binary data

        Raises:
            ArchiveStorageError: If download fails
            FileNotFoundError: If object does not exist
        """
        try:
            response = self.client.get_object(Bucket=self.bucket, Key=key)
            yield from response["Body"].iter_chunks()
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code")
            if error_code == "NoSuchKey":
                raise FileNotFoundError(f"Archive object not found: {key}")
            raise ArchiveStorageError(f"Failed to stream object '{key}': {e}")

    def object_exists(self, key: str) -> bool:
        """
        Check if an object exists in the archive storage.

        Args:
            key: Object key (path) within the bucket

        Returns:
            True if object exists, False otherwise
        """
        try:
            self.client.head_object(Bucket=self.bucket, Key=key)
            return True
        except ClientError:
            return False

    def delete_object(self, key: str) -> None:
        """
        Delete an object from the archive storage.

        Args:
            key: Object key (path) within the bucket

        Raises:
            ArchiveStorageError: If deletion fails
        """
        try:
            self.client.delete_object(Bucket=self.bucket, Key=key)
            logger.debug("Deleted object: %s", key)
        except ClientError as e:
            raise ArchiveStorageError(f"Failed to delete object '{key}': {e}")

    def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str:
        """
        Generate a pre-signed URL for downloading an object.

        Args:
            key: Object key (path) within the bucket
            expires_in: URL validity duration in seconds (default: 1 hour)

        Returns:
            Pre-signed URL string.

        Raises:
            ArchiveStorageError: If generation fails
        """
        try:
            return self.client.generate_presigned_url(
                ClientMethod="get_object",
                Params={"Bucket": self.bucket, "Key": key},
                ExpiresIn=expires_in,
            )
        except ClientError as e:
            raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}")

    def list_objects(self, prefix: str) -> list[str]:
        """
        List objects under a given prefix.

        Args:
            prefix: Object key prefix to filter by

        Returns:
            List of object keys matching the prefix
        """
        keys = []
        paginator = self.client.get_paginator("list_objects_v2")

        try:
            for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
                for obj in page.get("Contents", []):
                    keys.append(obj["Key"])
        except ClientError as e:
            raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}")

        return keys

    @staticmethod
    def _content_md5(data: bytes) -> str:
        """Calculate base64-encoded MD5 for Content-MD5 header."""
        return base64.b64encode(hashlib.md5(data).digest()).decode()

    @staticmethod
    def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes:
        """
        Serialize records to gzipped JSONL format.

        Args:
            records: List of dictionaries to serialize

        Returns:
            Gzipped JSONL bytes
        """
        lines = []
        for record in records:
            # Convert datetime objects to ISO format strings
            serialized = ArchiveStorage._serialize_record(record)
            lines.append(orjson.dumps(serialized))

        jsonl_content = b"\n".join(lines)
        if jsonl_content:
            jsonl_content += b"\n"

        return gzip.compress(jsonl_content)

    @staticmethod
    def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]:
        """
        Deserialize gzipped JSONL data to records.

        Args:
            data: Gzipped JSONL bytes

        Returns:
            List of dictionaries
        """
        jsonl_content = gzip.decompress(data)
        records = []

        for line in jsonl_content.splitlines():
            if line:
                records.append(orjson.loads(line))

        return records

    @staticmethod
    def _serialize_record(record: dict[str, Any]) -> dict[str, Any]:
        """Serialize a single record, converting special types."""

        def _serialize(item: Any) -> Any:
            if isinstance(item, datetime.datetime):
                return item.isoformat()
            if isinstance(item, dict):
                return {key: _serialize(value) for key, value in item.items()}
            if isinstance(item, list):
                return [_serialize(value) for value in item]
            return item

        return cast(dict[str, Any], _serialize(record))

    @staticmethod
    def compute_checksum(data: bytes) -> str:
        """Compute MD5 checksum of data."""
        return hashlib.md5(data).hexdigest()


# Singleton instance (lazy initialization)
_archive_storage: ArchiveStorage | None = None
_export_storage: ArchiveStorage | None = None


def get_archive_storage() -> ArchiveStorage:
    """
    Get the archive storage singleton instance.

    Returns:
        ArchiveStorage instance

    Raises:
        ArchiveStorageNotConfiguredError: If archive storage is not configured
    """
    global _archive_storage
    if _archive_storage is None:
        archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET
        if not archive_bucket:
            raise ArchiveStorageNotConfiguredError(
                "Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET"
            )
        _archive_storage = ArchiveStorage(bucket=archive_bucket)
    return _archive_storage


def get_export_storage() -> ArchiveStorage:
    """
    Get the export storage singleton instance.

    Returns:
        ArchiveStorage instance
    """
    global _export_storage
    if _export_storage is None:
        export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET
        if not export_bucket:
            raise ArchiveStorageNotConfiguredError(
                "Archive export bucket is not configured. Required: ARCHIVE_STORAGE_EXPORT_BUCKET"
            )
        _export_storage = ArchiveStorage(bucket=export_bucket)
    return _export_storage
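A minimal round-trip sketch for the new module (the object key is illustrative; assumes the `ARCHIVE_STORAGE_*` settings and archive bucket are configured):

```python
from libs.archive_storage import get_archive_storage

storage = get_archive_storage()
records = [{"id": "run-1", "status": "succeeded"}, {"id": "run-2", "status": "failed"}]

payload = storage.serialize_to_jsonl_gz(records)           # gzipped JSONL bytes
checksum = storage.put_object("archives/demo/batch-0001.jsonl.gz", payload)  # hypothetical key
assert checksum == storage.compute_checksum(payload)       # MD5 verified end to end

restored = storage.deserialize_from_jsonl_gz(storage.get_object("archives/demo/batch-0001.jsonl.gz"))
assert restored == records
```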
@ -8,7 +8,7 @@ from uuid import uuid4
import sqlalchemy as sa
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column
from sqlalchemy.orm import Mapped, Session, mapped_column, validates
from typing_extensions import deprecated

from .base import TypeBase
@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase):
    role: TenantAccountRole | None = field(default=None, init=False)
    _current_tenant: "Tenant | None" = field(default=None, init=False)

    @validates("status")
    def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
        if isinstance(value, AccountStatus):
            return value.value
        return value

    @property
    def is_password_set(self):
        return self.password is not None

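A standalone analog of what the validator does at assignment time (not the actual model; the enum values are assumed from the enum's usual shape):

```python
from enum import Enum


class AccountStatus(str, Enum):
    ACTIVE = "active"
    BANNED = "banned"


def normalize_status(value: "str | AccountStatus") -> str:
    """Mirror of the @validates hook: enum members collapse to their string value."""
    if isinstance(value, AccountStatus):
        return value.value
    return value


assert normalize_status(AccountStatus.ACTIVE) == "active"  # enum normalized
assert normalize_status("banned") == "banned"              # strings pass through
```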
@ -16,6 +16,11 @@ celery_redis = Redis(
    port=redis_config.get("port") or 6379,
    password=redis_config.get("password") or None,
    db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
    ssl=bool(dify_config.BROKER_USE_SSL),
    ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None,
    ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
    ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
    ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
)

logger = logging.getLogger(__name__)

@ -1,3 +1,4 @@
import json
import logging
import os
from collections.abc import Sequence
@ -31,6 +32,11 @@ class BillingService:

    compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)

    # Redis key prefix for tenant plan cache
    _PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
    # Cache TTL: 10 minutes
    _PLAN_CACHE_TTL = 600

    @classmethod
    def get_info(cls, tenant_id: str):
        params = {"tenant_id": tenant_id}
@ -272,14 +278,110 @@ class BillingService:
                data = resp.get("data", {})

                for tenant_id, plan in data.items():
                    subscription_plan = subscription_adapter.validate_python(plan)
                    results[tenant_id] = subscription_plan
                    try:
                        subscription_plan = subscription_adapter.validate_python(plan)
                        results[tenant_id] = subscription_plan
                    except Exception:
                        logger.exception(
                            "get_plan_bulk: failed to validate subscription plan for tenant(%s)", tenant_id
                        )
                        continue
            except Exception:
                logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
                logger.exception("get_plan_bulk: failed to fetch billing info batch for tenants: %s", chunk)
                continue

        return results

    @classmethod
    def _make_plan_cache_key(cls, tenant_id: str) -> str:
        return f"{cls._PLAN_CACHE_KEY_PREFIX}{tenant_id}"

    @classmethod
    def get_plan_bulk_with_cache(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
        """
        Bulk fetch billing subscription plan with cache to reduce billing API loads in batch job scenarios.

        NOTE: if you need strong data consistency, use get_plan_bulk instead.

        Returns:
            Mapping of tenant_id -> {plan: str, expiration_date: int}
        """
        tenant_plans: dict[str, SubscriptionPlan] = {}

        if not tenant_ids:
            return tenant_plans

        subscription_adapter = TypeAdapter(SubscriptionPlan)

        # Step 1: Batch fetch from Redis cache using mget
        redis_keys = [cls._make_plan_cache_key(tenant_id) for tenant_id in tenant_ids]
        try:
            cached_values = redis_client.mget(redis_keys)

            if len(cached_values) != len(tenant_ids):
                raise Exception(
                    "get_plan_bulk_with_cache: unexpected error: redis mget failed: cached values length mismatch"
                )

            # Map cached values back to tenant_ids
            cache_misses: list[str] = []

            for tenant_id, cached_value in zip(tenant_ids, cached_values):
                if cached_value:
                    try:
                        # Redis returns bytes, decode to string and parse JSON
                        json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
                        plan_dict = json.loads(json_str)
                        subscription_plan = subscription_adapter.validate_python(plan_dict)
                        tenant_plans[tenant_id] = subscription_plan
                    except Exception:
                        logger.exception(
                            "get_plan_bulk_with_cache: process tenant(%s) failed, add to cache misses", tenant_id
                        )
                        cache_misses.append(tenant_id)
                else:
                    cache_misses.append(tenant_id)

            logger.info(
                "get_plan_bulk_with_cache: cache hits=%s, cache misses=%s",
                len(tenant_plans),
                len(cache_misses),
            )
        except Exception:
            logger.exception("get_plan_bulk_with_cache: redis mget failed, falling back to API")
            cache_misses = list(tenant_ids)

        # Step 2: Fetch missing plans from billing API
        if cache_misses:
            bulk_plans = BillingService.get_plan_bulk(cache_misses)

            if bulk_plans:
                plans_to_cache: dict[str, SubscriptionPlan] = {}

                for tenant_id, subscription_plan in bulk_plans.items():
                    tenant_plans[tenant_id] = subscription_plan
                    plans_to_cache[tenant_id] = subscription_plan

                # Step 3: Batch update Redis cache using pipeline
                if plans_to_cache:
                    try:
                        pipe = redis_client.pipeline()
                        for tenant_id, subscription_plan in plans_to_cache.items():
                            redis_key = cls._make_plan_cache_key(tenant_id)
                            # Serialize dict to JSON string
                            json_str = json.dumps(subscription_plan)
                            pipe.setex(redis_key, cls._PLAN_CACHE_TTL, json_str)
                        pipe.execute()

                        logger.info(
                            "get_plan_bulk_with_cache: cached %s new tenant plans to Redis",
                            len(plans_to_cache),
                        )
                    except Exception:
                        logger.exception("get_plan_bulk_with_cache: redis pipeline failed")

        return tenant_plans

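Cache-aside usage sketch (tenant IDs are illustrative); results may lag the billing API by up to the 10-minute TTL:

```python
plans = BillingService.get_plan_bulk_with_cache(["tenant-a", "tenant-b"])
for tenant_id, plan in plans.items():
    print(tenant_id, plan["plan"], plan["expiration_date"])
# Tenants whose plan could not be fetched or validated are simply absent from the result.
```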
    @classmethod
    def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
        resp = cls._send_request("GET", "/subscription/cleanup/whitelist")

@ -7,7 +7,6 @@ from httpx import get
from sqlalchemy import select

from core.entities.provider_entities import ProviderConfig
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.provider import ApiToolProviderController
@ -86,7 +85,9 @@ class ApiToolManageService:
            raise ValueError(f"invalid schema: {str(e)}")

    @staticmethod
    def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
    def convert_schema_to_tool_bundles(
        schema: str, extra_info: dict | None = None
    ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
        """
        convert schema to tool bundles

@ -104,7 +105,7 @@ class ApiToolManageService:
        provider_name: str,
        icon: dict,
        credentials: dict,
        schema_type: str,
        schema_type: ApiProviderSchemaType,
        schema: str,
        privacy_policy: str,
        custom_disclaimer: str,
@ -113,9 +114,6 @@ class ApiToolManageService:
        """
        create api tool provider
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f"invalid schema type {schema}")

        provider_name = provider_name.strip()

        # check if the provider exists
@ -178,9 +176,6 @@ class ApiToolManageService:
        # update labels
        ToolLabelManager.update_tool_labels(provider_controller, labels)

        # Invalidate tool providers cache
        ToolProviderListCache.invalidate_cache(tenant_id)

        return {"result": "success"}

    @staticmethod
@ -245,18 +240,15 @@ class ApiToolManageService:
        original_provider: str,
        icon: dict,
        credentials: dict,
        schema_type: str,
        _schema_type: ApiProviderSchemaType,
        schema: str,
        privacy_policy: str,
        privacy_policy: str | None,
        custom_disclaimer: str,
        labels: list[str],
    ):
        """
        update api tool provider
        """
        if schema_type not in [member.value for member in ApiProviderSchemaType]:
            raise ValueError(f"invalid schema type {schema}")

        provider_name = provider_name.strip()

        # check if the provider exists
@ -281,7 +273,7 @@ class ApiToolManageService:
        provider.icon = json.dumps(icon)
        provider.schema = schema
        provider.description = extra_info.get("description", "")
        provider.schema_type_str = ApiProviderSchemaType.OPENAPI
        provider.schema_type_str = schema_type
        provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
        provider.privacy_policy = privacy_policy
        provider.custom_disclaimer = custom_disclaimer
@ -322,9 +314,6 @@ class ApiToolManageService:
        # update labels
        ToolLabelManager.update_tool_labels(provider_controller, labels)

        # Invalidate tool providers cache
        ToolProviderListCache.invalidate_cache(tenant_id)

        return {"result": "success"}

    @staticmethod
@ -347,9 +336,6 @@ class ApiToolManageService:
        db.session.delete(provider)
        db.session.commit()

        # Invalidate tool providers cache
        ToolProviderListCache.invalidate_cache(tenant_id)

        return {"result": "success"}

    @staticmethod
@ -366,7 +352,7 @@ class ApiToolManageService:
        tool_name: str,
        credentials: dict,
        parameters: dict,
        schema_type: str,
        schema_type: ApiProviderSchemaType,
        schema: str,
    ):
        """

@ -12,7 +12,6 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -205,9 +204,6 @@ class BuiltinToolManageService:
                db_provider.name = name

                session.commit()

                # Invalidate tool providers cache
                ToolProviderListCache.invalidate_cache(tenant_id)
            except Exception as e:
                session.rollback()
                raise ValueError(str(e))
@ -290,8 +286,6 @@ class BuiltinToolManageService:
                session.rollback()
                raise ValueError(str(e))

        # Invalidate tool providers cache
        ToolProviderListCache.invalidate_cache(tenant_id, "builtin")
        return {"result": "success"}

    @staticmethod
@ -409,9 +403,6 @@ class BuiltinToolManageService:
        )
        cache.delete()

        # Invalidate tool providers cache
        ToolProviderListCache.invalidate_cache(tenant_id)

        return {"result": "success"}

    @staticmethod
@ -434,8 +425,6 @@ class BuiltinToolManageService:
            target_provider.is_default = True
            session.commit()

            # Invalidate tool providers cache
            ToolProviderListCache.invalidate_cache(tenant_id)
            return {"result": "success"}

    @staticmethod

@ -1,6 +1,5 @@
import logging

from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
from services.tools.tools_transform_service import ToolTransformService
@ -16,14 +15,6 @@ class ToolCommonService:

        :return: the list of tool providers
        """
        # Try to get from cache first
        cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
        if cached_result is not None:
            logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ)
            return cached_result

        # Cache miss - fetch from database
        logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ)
        providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)

        # add icon
@ -32,7 +23,4 @@ class ToolCommonService:

        result = [provider.to_dict() for provider in providers]

        # Cache the result
        ToolProviderListCache.set_cached_providers(tenant_id, typ, result)

        return result

@ -5,9 +5,8 @@ from datetime import datetime
from typing import Any

from sqlalchemy import or_, select
from sqlalchemy.orm import Session

from core.db.session_factory import session_factory
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
@ -86,17 +85,13 @@ class WorkflowToolManageService:
        except Exception as e:
            raise ValueError(str(e))

        with session_factory.create_session() as session, session.begin():
        with Session(db.engine, expire_on_commit=False) as session, session.begin():
            session.add(workflow_tool_provider)

            if labels is not None:
                ToolLabelManager.update_tool_labels(
                    ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
                )

        # Invalidate tool providers cache
        ToolProviderListCache.invalidate_cache(tenant_id)

        return {"result": "success"}

    @classmethod
@ -184,9 +179,6 @@ class WorkflowToolManageService:
                ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
            )

        # Invalidate tool providers cache
        ToolProviderListCache.invalidate_cache(tenant_id)

        return {"result": "success"}

    @classmethod
@ -249,9 +241,6 @@ class WorkflowToolManageService:

        db.session.commit()

        # Invalidate tool providers cache
        ToolProviderListCache.invalidate_cache(tenant_id)

        return {"result": "success"}

    @classmethod

@ -7,11 +7,14 @@ CODE_LANGUAGE = CodeLanguage.JINJA2


def test_jinja2():
    """Test basic Jinja2 template rendering."""
    template = "Hello {{template}}"
    # Template must be base64 encoded to match the new safe embedding approach
    template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
    inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
    code = (
        Jinja2TemplateTransformer.get_runner_script()
        .replace(Jinja2TemplateTransformer._code_placeholder, template)
        .replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
        .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
    )
    result = CodeExecutor.execute_code(
@ -21,6 +24,7 @@ def test_jinja2():


def test_jinja2_with_code_template():
    """Test template rendering via the high-level workflow API."""
    result = CodeExecutor.execute_workflow_code_template(
        language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"}
    )
@ -28,7 +32,64 @@ def test_jinja2_with_code_template():


def test_jinja2_get_runner_script():
    """Test that runner script contains required placeholders."""
    runner_script = Jinja2TemplateTransformer.get_runner_script()
    assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
    assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
    assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
    assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2


def test_jinja2_template_with_special_characters():
    """
    Test that templates with special characters (quotes, newlines) render correctly.
    This is a regression test for issue #26818 where textarea pre-fill values
    containing special characters would break template rendering.
    """
    # Template with triple quotes, single quotes, double quotes, and newlines
    template = """<html>
<body>
<input value="{{ task.get('Task ID', '') }}"/>
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
<p>Status: "{{ status }}"</p>
<pre>'''code block'''</pre>
</body>
</html>"""
    inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}

    result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)

    # Verify the template rendered correctly with all special characters
    output = result["result"]
    assert 'value="TASK-123"' in output
    assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
    assert 'Status: "completed"' in output
    assert "'''code block'''" in output


def test_jinja2_template_with_html_textarea_prefill():
    """
    Specific test for HTML textarea with Jinja2 variable pre-fill.
    Verifies fix for issue #26818.
    """
    template = "<textarea name='notes'>{{ notes }}</textarea>"
    notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes."
    inputs = {"notes": notes_content}

    result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs)

    expected_output = f"<textarea name='notes'>{notes_content}</textarea>"
    assert result["result"] == expected_output


def test_jinja2_assemble_runner_script_encodes_template():
    """Test that assemble_runner_script properly base64 encodes the template."""
    template = "Hello {{ name }}!"
    inputs = {"name": "World"}

    script = Jinja2TemplateTransformer.assemble_runner_script(template, inputs)

    # The template should be base64 encoded in the script
    template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
    assert template_b64 in script
    # The raw template should NOT appear in the script (it's encoded)
    assert "Hello {{ name }}!" not in script

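Why base64 embedding sidesteps the quoting problem these tests exercise (standalone sketch, not the actual runner template): the template travels as a quote-free token and is only decoded inside the sandbox.

```python
import base64

template = 'Line 1\n"double" and \'single\' quotes'
token = base64.b64encode(template.encode("utf-8")).decode("utf-8")
# The base64 alphabet (A-Z a-z 0-9 + / =) contains no quotes or newlines,
# so the token can be spliced into a single-quoted literal safely.
runner = f"import base64\nTEMPLATE = base64.b64decode('{token}').decode('utf-8')\nprint(TEMPLATE)"
exec(runner)  # prints the original template intact
```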
@ -0,0 +1,365 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
class TestBillingServiceGetPlanBulkWithCache:
|
||||
"""
|
||||
Comprehensive integration tests for get_plan_bulk_with_cache using testcontainers.
|
||||
|
||||
This test class covers all major scenarios:
|
||||
- Cache hit/miss scenarios
|
||||
- Redis operation failures and fallback behavior
|
||||
- Invalid cache data handling
|
||||
- TTL expiration handling
|
||||
- Error recovery and logging
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_redis_cleanup(self, flask_app_with_containers):
|
||||
"""Clean up Redis cache before and after each test."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Clean up before test
|
||||
yield
|
||||
# Clean up after test
|
||||
# Delete all test cache keys
|
||||
pattern = f"{BillingService._PLAN_CACHE_KEY_PREFIX}*"
|
||||
keys = redis_client.keys(pattern)
|
||||
if keys:
|
||||
redis_client.delete(*keys)
|
||||
|
||||
def _create_test_plan_data(self, plan: str = "sandbox", expiration_date: int = 1735689600):
|
||||
"""Helper to create test SubscriptionPlan data."""
|
||||
return {"plan": plan, "expiration_date": expiration_date}
|
||||
|
||||
def _set_cache(self, tenant_id: str, plan_data: dict, ttl: int = 600):
|
||||
"""Helper to set cache data in Redis."""
|
||||
cache_key = BillingService._make_plan_cache_key(tenant_id)
|
||||
json_str = json.dumps(plan_data)
|
||||
redis_client.setex(cache_key, ttl, json_str)
|
||||
|
||||
def _get_cache(self, tenant_id: str):
|
||||
"""Helper to get cache data from Redis."""
|
||||
cache_key = BillingService._make_plan_cache_key(tenant_id)
|
||||
value = redis_client.get(cache_key)
|
||||
if value:
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
return value
|
||||
return None
|
||||
|
||||
def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers):
|
||||
"""Test bulk plan retrieval when all tenants are in cache."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
|
||||
expected_plans = {
|
||||
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||
"tenant-3": self._create_test_plan_data("team", 1798761600),
|
||||
}
|
||||
|
||||
# Pre-populate cache
|
||||
for tenant_id, plan_data in expected_plans.items():
|
||||
self._set_cache(tenant_id, plan_data)
|
||||
|
||||
# Act
|
||||
with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk:
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
assert result["tenant-1"]["plan"] == "sandbox"
|
||||
assert result["tenant-1"]["expiration_date"] == 1735689600
|
||||
assert result["tenant-2"]["plan"] == "professional"
|
||||
assert result["tenant-2"]["expiration_date"] == 1767225600
|
||||
assert result["tenant-3"]["plan"] == "team"
|
||||
assert result["tenant-3"]["expiration_date"] == 1798761600
|
||||
|
||||
# Verify API was not called
|
||||
mock_get_plan_bulk.assert_not_called()
|
||||
|
||||
def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers):
|
||||
"""Test bulk plan retrieval when all tenants are not in cache."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2"]
|
||||
expected_plans = {
|
||||
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||
}
|
||||
|
||||
# Act
|
||||
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result["tenant-1"]["plan"] == "sandbox"
|
||||
assert result["tenant-2"]["plan"] == "professional"
|
||||
|
||||
# Verify API was called with correct tenant_ids
|
||||
mock_get_plan_bulk.assert_called_once_with(tenant_ids)
|
||||
|
||||
# Verify data was written to cache
|
||||
cached_1 = self._get_cache("tenant-1")
|
||||
cached_2 = self._get_cache("tenant-2")
|
||||
assert cached_1 is not None
|
||||
assert cached_2 is not None
|
||||
|
||||
# Verify cache content
|
||||
cached_data_1 = json.loads(cached_1)
|
||||
cached_data_2 = json.loads(cached_2)
|
||||
assert cached_data_1 == expected_plans["tenant-1"]
|
||||
assert cached_data_2 == expected_plans["tenant-2"]
|
||||
|
||||
# Verify TTL is set
|
||||
cache_key_1 = BillingService._make_plan_cache_key("tenant-1")
|
||||
ttl_1 = redis_client.ttl(cache_key_1)
|
||||
assert ttl_1 > 0
|
||||
assert ttl_1 <= 600 # Should be <= 600 seconds
|
||||
|
||||
def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers):
|
||||
"""Test bulk plan retrieval when some tenants are in cache, some are not."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
|
||||
# Pre-populate cache for tenant-1 and tenant-2
|
||||
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
|
||||
self._set_cache("tenant-2", self._create_test_plan_data("professional", 1767225600))
|
||||
|
||||
# tenant-3 is not in cache
|
||||
missing_plan = {"tenant-3": self._create_test_plan_data("team", 1798761600)}
|
||||
|
||||
# Act
|
||||
with patch.object(BillingService, "get_plan_bulk", return_value=missing_plan) as mock_get_plan_bulk:
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
assert result["tenant-1"]["plan"] == "sandbox"
|
||||
assert result["tenant-2"]["plan"] == "professional"
|
||||
assert result["tenant-3"]["plan"] == "team"
|
||||
|
||||
# Verify API was called only for missing tenant
|
||||
mock_get_plan_bulk.assert_called_once_with(["tenant-3"])
|
||||
|
||||
# Verify tenant-3 data was written to cache
|
||||
cached_3 = self._get_cache("tenant-3")
|
||||
assert cached_3 is not None
|
||||
cached_data_3 = json.loads(cached_3)
|
||||
assert cached_data_3 == missing_plan["tenant-3"]
|
||||
|
||||
def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers):
|
||||
"""Test fallback to API when Redis mget fails."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
tenant_ids = ["tenant-1", "tenant-2"]
|
||||
expected_plans = {
|
||||
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
|
||||
"tenant-2": self._create_test_plan_data("professional", 1767225600),
|
||||
}
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch.object(redis_client, "mget", side_effect=Exception("Redis connection error")),
|
||||
patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk,
|
||||
):
|
||||
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result["tenant-1"]["plan"] == "sandbox"
|
||||
assert result["tenant-2"]["plan"] == "professional"
|
||||
|
||||
# Verify API was called for all tenants (fallback)
|
||||
mock_get_plan_bulk.assert_called_once_with(tenant_ids)
|
||||
|
||||
# Verify data was written to cache after fallback
|
||||
cached_1 = self._get_cache("tenant-1")
|
||||
cached_2 = self._get_cache("tenant-2")
|
||||
assert cached_1 is not None
|
||||
assert cached_2 is not None
|
||||
|
||||
    def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers):
        """Test fallback to API when cache contains invalid JSON."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]

            # Set valid cache for tenant-1
            self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))

            # Set invalid JSON for tenant-2
            cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
            redis_client.setex(cache_key_2, 600, "invalid json {")

            # tenant-3 is not in cache
            expected_plans = {
                "tenant-2": self._create_test_plan_data("professional", 1767225600),
                "tenant-3": self._create_test_plan_data("team", 1798761600),
            }

            # Act
            with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
                result = BillingService.get_plan_bulk_with_cache(tenant_ids)

            # Assert
            assert len(result) == 3
            assert result["tenant-1"]["plan"] == "sandbox"  # From cache
            assert result["tenant-2"]["plan"] == "professional"  # From API (fallback)
            assert result["tenant-3"]["plan"] == "team"  # From API

            # Verify API was called for tenant-2 and tenant-3
            mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])

            # Verify tenant-2's invalid JSON was replaced with correct data in cache
            cached_2 = self._get_cache("tenant-2")
            assert cached_2 is not None
            cached_data_2 = json.loads(cached_2)
            assert cached_data_2 == expected_plans["tenant-2"]
            assert cached_data_2["plan"] == "professional"
            assert cached_data_2["expiration_date"] == 1767225600

            # Verify tenant-2 cache has correct TTL
            cache_key_2_new = BillingService._make_plan_cache_key("tenant-2")
            ttl_2 = redis_client.ttl(cache_key_2_new)
            assert ttl_2 > 0
            assert ttl_2 <= 600

            # Verify tenant-3 data was also written to cache
            cached_3 = self._get_cache("tenant-3")
            assert cached_3 is not None
            cached_data_3 = json.loads(cached_3)
            assert cached_data_3 == expected_plans["tenant-3"]

    def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers):
        """Test fallback to API when cache data doesn't match SubscriptionPlan schema."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]

            # Set valid cache for tenant-1
            self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))

            # Set invalid plan data for tenant-2 (missing expiration_date)
            cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
            invalid_data = json.dumps({"plan": "professional"})  # Missing expiration_date
            redis_client.setex(cache_key_2, 600, invalid_data)

            # tenant-3 is not in cache
            expected_plans = {
                "tenant-2": self._create_test_plan_data("professional", 1767225600),
                "tenant-3": self._create_test_plan_data("team", 1798761600),
            }

            # Act
            with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
                result = BillingService.get_plan_bulk_with_cache(tenant_ids)

            # Assert
            assert len(result) == 3
            assert result["tenant-1"]["plan"] == "sandbox"  # From cache
            assert result["tenant-2"]["plan"] == "professional"  # From API (fallback)
            assert result["tenant-3"]["plan"] == "team"  # From API

            # Verify API was called for tenant-2 and tenant-3
            mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])

    def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers):
        """Test that pipeline failure doesn't affect return value."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2"]
            expected_plans = {
                "tenant-1": self._create_test_plan_data("sandbox", 1735689600),
                "tenant-2": self._create_test_plan_data("professional", 1767225600),
            }

            # Act
            with (
                patch.object(BillingService, "get_plan_bulk", return_value=expected_plans),
                patch.object(redis_client, "pipeline") as mock_pipeline,
            ):
                # Create a mock pipeline that fails on execute
                mock_pipe = mock_pipeline.return_value
                mock_pipe.execute.side_effect = Exception("Pipeline execution failed")

                result = BillingService.get_plan_bulk_with_cache(tenant_ids)

            # Assert - Function should still return correct result despite pipeline failure
            assert len(result) == 2
            assert result["tenant-1"]["plan"] == "sandbox"
            assert result["tenant-2"]["plan"] == "professional"

            # Verify pipeline was attempted
            mock_pipeline.assert_called_once()

    def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers):
        """Test with empty tenant_ids list."""
        with flask_app_with_containers.app_context():
            # Act
            with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk:
                result = BillingService.get_plan_bulk_with_cache([])

            # Assert
            assert result == {}
            assert len(result) == 0

            # Verify no API calls
            mock_get_plan_bulk.assert_not_called()

            # Ideally we would also assert that redis_client.mget was never called,
            # but that needs extra mocking, so verifying the empty result is enough here.

    def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers):
        """Test that expired cache keys are treated as cache misses."""
        with flask_app_with_containers.app_context():
            # Arrange
            tenant_ids = ["tenant-1", "tenant-2"]

            # Set cache for tenant-1 with very short TTL (1 second) to simulate expiration
            self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600), ttl=1)

            # Wait for TTL to expire (key will be deleted by Redis)
            import time

            time.sleep(2)

            # Verify cache is expired (key doesn't exist)
            cache_key_1 = BillingService._make_plan_cache_key("tenant-1")
            exists = redis_client.exists(cache_key_1)
            assert exists == 0  # Key doesn't exist (expired)

            # tenant-2 is not in cache
            expected_plans = {
                "tenant-1": self._create_test_plan_data("sandbox", 1735689600),
                "tenant-2": self._create_test_plan_data("professional", 1767225600),
            }

            # Act
            with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
                result = BillingService.get_plan_bulk_with_cache(tenant_ids)

            # Assert
            assert len(result) == 2
            assert result["tenant-1"]["plan"] == "sandbox"
            assert result["tenant-2"]["plan"] == "professional"

            # Verify API was called for both tenants (tenant-1 expired, tenant-2 missing)
            mock_get_plan_bulk.assert_called_once_with(tenant_ids)

            # Verify both were written to cache with correct TTL
            cache_key_1_new = BillingService._make_plan_cache_key("tenant-1")
            cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
            ttl_1_new = redis_client.ttl(cache_key_1_new)
            ttl_2 = redis_client.ttl(cache_key_2)
            assert ttl_1_new > 0
            assert ttl_1_new <= 600
            assert ttl_2 > 0
            assert ttl_2 <= 600

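The tests above collectively pin down a read-through bulk cache: `mget` first, treat unparsable entries as misses, fetch only the missing tenants from the billing API, then write the fetched plans back with a TTL via a pipeline whose failure is swallowed. A minimal sketch of that pattern, assuming a hypothetical `fetch_bulk` callable and key scheme (the real BillingService internals are not part of this diff):

```python
# Read-through bulk cache sketch. fetch_bulk, the key format, and the
# 600-second TTL are illustrative assumptions, not BillingService's code.
import json
from typing import Callable

CACHE_TTL_SECONDS = 600


def get_bulk_with_cache(redis, tenant_ids: list[str], fetch_bulk: Callable[[list[str]], dict]) -> dict:
    if not tenant_ids:
        return {}  # no tenants: no Redis and no API traffic at all
    keys = [f"billing:plan:{tid}" for tid in tenant_ids]
    try:
        raw_values = redis.mget(keys)
    except Exception:
        raw_values = [None] * len(tenant_ids)  # a failed mget is a full cache miss
    results: dict[str, dict] = {}
    missing: list[str] = []
    for tid, raw in zip(tenant_ids, raw_values):
        plan = None
        if raw is not None:
            try:
                plan = json.loads(raw)
            except ValueError:
                plan = None  # invalid JSON in the cache also counts as a miss
        if plan is None:
            missing.append(tid)
        else:
            results[tid] = plan
    if missing:
        fetched = fetch_bulk(missing)
        results.update(fetched)
        try:
            pipe = redis.pipeline()
            for tid, plan in fetched.items():
                pipe.setex(f"billing:plan:{tid}", CACHE_TTL_SECONDS, json.dumps(plan))
            pipe.execute()
        except Exception:
            pass  # a failed write-back must not affect the returned plans
    return results
```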
@ -12,10 +12,12 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
        _, Jinja2TemplateTransformer = self.jinja2_imports

        template = "Hello {{template}}"
        # Template must be base64 encoded to match the new safe embedding approach
        template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
        inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
        code = (
            Jinja2TemplateTransformer.get_runner_script()
            .replace(Jinja2TemplateTransformer._code_placeholder, template)
            .replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
            .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
        )
        result = CodeExecutor.execute_code(
@ -37,6 +39,34 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
        _, Jinja2TemplateTransformer = self.jinja2_imports

        runner_script = Jinja2TemplateTransformer.get_runner_script()
        assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
        assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
        assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
        assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2

    def test_jinja2_template_with_special_characters(self, flask_app_with_containers):
        """
        Test that templates with special characters (quotes, newlines) render correctly.
        This is a regression test for issue #26818 where textarea pre-fill values
        containing special characters would break template rendering.
        """
        CodeExecutor, CodeLanguage = self.code_executor_imports

        # Template with triple quotes, single quotes, double quotes, and newlines
        template = """<html>
<body>
<input value="{{ task.get('Task ID', '') }}"/>
<textarea>{{ task.get('Issues', 'No issues reported') }}</textarea>
<p>Status: "{{ status }}"</p>
<pre>'''code block'''</pre>
</body>
</html>"""
        inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"}

        result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs)

        # Verify the template rendered correctly with all special characters
        output = result["result"]
        assert 'value="TASK-123"' in output
        assert "<textarea>Line 1\nLine 2\nLine 3</textarea>" in output
        assert 'Status: "completed"' in output
        assert "'''code block'''" in output

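The two hunks above switch the runner script to base64-embedded templates. The point: splicing a raw template string into generated source breaks when the template contains quotes, newlines, or triple quotes, while a base64 token is always a safe literal. A small illustration (placeholder names here are made up, not the actual Jinja2TemplateTransformer constants):

```python
# Base64-safe template embedding, sketched with hypothetical placeholders.
import base64

RUNNER = """
import base64
template = base64.b64decode("{{TEMPLATE_B64}}").decode("utf-8")
inputs = base64.b64decode("{{INPUTS_B64}}").decode("utf-8")
"""

template = 'line one\n"quoted" and \'\'\'tripled\'\'\''  # would break naive splicing
code = RUNNER.replace(
    "{{TEMPLATE_B64}}", base64.b64encode(template.encode("utf-8")).decode("utf-8")
).replace(
    "{{INPUTS_B64}}", base64.b64encode(b'{"name": "World"}').decode("utf-8")
)
# `code` is valid Python no matter which characters the template contains.
compile(code, "<runner>", "exec")
```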
69   api/tests/unit_tests/controllers/common/test_fields.py    Normal file
@ -0,0 +1,69 @@
import builtins
from types import SimpleNamespace
from unittest.mock import patch

from flask.views import MethodView as FlaskMethodView

_NEEDS_METHOD_VIEW_CLEANUP = False
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = FlaskMethodView
    _NEEDS_METHOD_VIEW_CLEANUP = True
from controllers.common.fields import Parameters, Site
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import IconType


def test_parameters_model_round_trip():
    parameters = get_parameters_from_feature_dict(features_dict={}, user_input_form=[])

    model = Parameters.model_validate(parameters)

    assert model.model_dump(mode="json") == parameters


def test_site_icon_url_uses_signed_url_for_image_icon():
    site = SimpleNamespace(
        title="Example",
        chat_color_theme=None,
        chat_color_theme_inverted=False,
        icon_type=IconType.IMAGE,
        icon="file-id",
        icon_background=None,
        description=None,
        copyright=None,
        privacy_policy=None,
        custom_disclaimer=None,
        default_language="en-US",
        show_workflow_steps=True,
        use_icon_as_answer_icon=False,
    )

    with patch("controllers.common.fields.file_helpers.get_signed_file_url", return_value="signed") as mock_helper:
        model = Site.model_validate(site)

    assert model.icon_url == "signed"
    mock_helper.assert_called_once_with("file-id")


def test_site_icon_url_is_none_for_non_image_icon():
    site = SimpleNamespace(
        title="Example",
        chat_color_theme=None,
        chat_color_theme_inverted=False,
        icon_type=IconType.EMOJI,
        icon="file-id",
        icon_background=None,
        description=None,
        copyright=None,
        privacy_policy=None,
        custom_disclaimer=None,
        default_language="en-US",
        show_workflow_steps=True,
        use_icon_as_answer_icon=False,
    )

    with patch("controllers.common.fields.file_helpers.get_signed_file_url") as mock_helper:
        model = Site.model_validate(site)

    assert model.icon_url is None
    mock_helper.assert_not_called()
@ -0,0 +1,254 @@
"""
Unit tests for XSS prevention in App payloads.

This test module validates that HTML tags, JavaScript, and other potentially
dangerous content are rejected in App names and descriptions.
"""

import pytest

from controllers.console.app.app import CopyAppPayload, CreateAppPayload, UpdateAppPayload


class TestXSSPreventionUnit:
    """Unit tests for XSS prevention in App payloads."""

    def test_create_app_valid_names(self):
        """Test CreateAppPayload with valid app names."""
        # Normal app names should be valid
        valid_names = [
            "My App",
            "Test App 123",
            "App with - dash",
            "App with _ underscore",
            "App with + plus",
            "App with () parentheses",
            "App with [] brackets",
            "App with {} braces",
            "App with ! exclamation",
            "App with @ at",
            "App with # hash",
            "App with $ dollar",
            "App with % percent",
            "App with ^ caret",
            "App with & ampersand",
            "App with * asterisk",
            "Unicode: 测试应用",
            "Emoji: 🤖",
            "Mixed: Test 测试 123",
        ]

        for name in valid_names:
            payload = CreateAppPayload(
                name=name,
                mode="chat",
            )
            assert payload.name == name

    def test_create_app_xss_script_tags(self):
        """Test CreateAppPayload rejects script tags."""
        xss_payloads = [
            "<script>alert(document.cookie)</script>",
            "<Script>alert(1)</Script>",
            "<SCRIPT>alert('XSS')</SCRIPT>",
            "<script>alert(String.fromCharCode(88,83,83))</script>",
            "<script src='evil.js'></script>",
            "<script>document.location='http://evil.com'</script>",
        ]

        for name in xss_payloads:
            with pytest.raises(ValueError) as exc_info:
                CreateAppPayload(name=name, mode="chat")
            assert "invalid characters or patterns" in str(exc_info.value).lower()

    def test_create_app_xss_iframe_tags(self):
        """Test CreateAppPayload rejects iframe tags."""
        xss_payloads = [
            "<iframe src='evil.com'></iframe>",
            "<Iframe srcdoc='<script>alert(1)</script>'></iframe>",
            "<IFRAME src='javascript:alert(1)'></iframe>",
        ]

        for name in xss_payloads:
            with pytest.raises(ValueError) as exc_info:
                CreateAppPayload(name=name, mode="chat")
            assert "invalid characters or patterns" in str(exc_info.value).lower()

    def test_create_app_xss_javascript_protocol(self):
        """Test CreateAppPayload rejects javascript: protocol."""
        xss_payloads = [
            "javascript:alert(1)",
            "JAVASCRIPT:alert(1)",
            "JavaScript:alert(document.cookie)",
            "javascript:void(0)",
            "javascript://comment%0Aalert(1)",
        ]

        for name in xss_payloads:
            with pytest.raises(ValueError) as exc_info:
                CreateAppPayload(name=name, mode="chat")
            assert "invalid characters or patterns" in str(exc_info.value).lower()

    def test_create_app_xss_svg_onload(self):
        """Test CreateAppPayload rejects SVG with onload."""
        xss_payloads = [
            "<svg onload=alert(1)>",
            "<SVG ONLOAD=alert(1)>",
            "<svg/x/onload=alert(1)>",
        ]

        for name in xss_payloads:
            with pytest.raises(ValueError) as exc_info:
                CreateAppPayload(name=name, mode="chat")
            assert "invalid characters or patterns" in str(exc_info.value).lower()

    def test_create_app_xss_event_handlers(self):
        """Test CreateAppPayload rejects HTML event handlers."""
        xss_payloads = [
            "<div onclick=alert(1)>",
            "<img onerror=alert(1)>",
            "<body onload=alert(1)>",
            "<input onfocus=alert(1)>",
            "<a onmouseover=alert(1)>",
            "<DIV ONCLICK=alert(1)>",
            "<img src=x onerror=alert(1)>",
        ]

        for name in xss_payloads:
            with pytest.raises(ValueError) as exc_info:
                CreateAppPayload(name=name, mode="chat")
            assert "invalid characters or patterns" in str(exc_info.value).lower()

    def test_create_app_xss_object_embed(self):
        """Test CreateAppPayload rejects object and embed tags."""
        xss_payloads = [
            "<object data='evil.swf'></object>",
            "<embed src='evil.swf'>",
            "<OBJECT data='javascript:alert(1)'></OBJECT>",
        ]

        for name in xss_payloads:
            with pytest.raises(ValueError) as exc_info:
                CreateAppPayload(name=name, mode="chat")
            assert "invalid characters or patterns" in str(exc_info.value).lower()

    def test_create_app_xss_link_javascript(self):
        """Test CreateAppPayload rejects link tags with javascript."""
        xss_payloads = [
            "<link href='javascript:alert(1)'>",
            "<LINK HREF='javascript:alert(1)'>",
        ]

        for name in xss_payloads:
            with pytest.raises(ValueError) as exc_info:
                CreateAppPayload(name=name, mode="chat")
            assert "invalid characters or patterns" in str(exc_info.value).lower()

    def test_create_app_xss_in_description(self):
        """Test CreateAppPayload rejects XSS in description."""
        xss_descriptions = [
            "<script>alert(1)</script>",
            "javascript:alert(1)",
            "<img onerror=alert(1)>",
        ]

        for description in xss_descriptions:
            with pytest.raises(ValueError) as exc_info:
                CreateAppPayload(
                    name="Valid Name",
                    mode="chat",
                    description=description,
                )
            assert "invalid characters or patterns" in str(exc_info.value).lower()

    def test_create_app_valid_descriptions(self):
        """Test CreateAppPayload with valid descriptions."""
        valid_descriptions = [
            "A simple description",
            "Description with < and > symbols",
            "Description with & ampersand",
            "Description with 'quotes' and \"double quotes\"",
            "Description with / slashes",
            "Description with \\ backslashes",
            "Description with ; semicolons",
            "Unicode: 这是一个描述",
            "Emoji: 🎉🚀",
        ]

        for description in valid_descriptions:
            payload = CreateAppPayload(
                name="Valid App Name",
                mode="chat",
                description=description,
            )
            assert payload.description == description

    def test_create_app_none_description(self):
        """Test CreateAppPayload with None description."""
        payload = CreateAppPayload(
            name="Valid App Name",
            mode="chat",
            description=None,
        )
        assert payload.description is None

    def test_update_app_xss_prevention(self):
        """Test UpdateAppPayload also prevents XSS."""
        xss_names = [
            "<script>alert(1)</script>",
            "javascript:alert(1)",
            "<img onerror=alert(1)>",
        ]

        for name in xss_names:
            with pytest.raises(ValueError) as exc_info:
                UpdateAppPayload(name=name)
            assert "invalid characters or patterns" in str(exc_info.value).lower()

    def test_update_app_valid_names(self):
        """Test UpdateAppPayload with valid names."""
        payload = UpdateAppPayload(name="Valid Updated Name")
        assert payload.name == "Valid Updated Name"

    def test_copy_app_xss_prevention(self):
        """Test CopyAppPayload also prevents XSS."""
        xss_names = [
            "<script>alert(1)</script>",
            "javascript:alert(1)",
            "<img onerror=alert(1)>",
        ]

        for name in xss_names:
            with pytest.raises(ValueError) as exc_info:
                CopyAppPayload(name=name)
            assert "invalid characters or patterns" in str(exc_info.value).lower()

    def test_copy_app_valid_names(self):
        """Test CopyAppPayload with valid names."""
        payload = CopyAppPayload(name="Valid Copy Name")
        assert payload.name == "Valid Copy Name"

    def test_copy_app_none_name(self):
        """Test CopyAppPayload with None name (should be allowed)."""
        payload = CopyAppPayload(name=None)
        assert payload.name is None

    def test_edge_case_angle_brackets_content(self):
        """Test that angle brackets with actual content are rejected."""
        # Bare angle brackets used as symbols are fine: the deny-list patterns only
        # match dangerous constructs, e.g. <.*?on\w+\s*= for inline event handlers.

        # Invalid: actual HTML tags with event handlers
        invalid_names = [
            "<div onclick=xss>",
            "<img src=x onerror=alert(1)>",
        ]

        for name in invalid_names:
            with pytest.raises(ValueError) as exc_info:
                CreateAppPayload(name=name, mode="chat")
            assert "invalid characters or patterns" in str(exc_info.value).lower()
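All of these tests assert the same ValueError message, which implies a shared deny-list validator on the payload models. The validator itself is not shown in this diff; the sketch below is one plausible shape for it, reusing the `<.*?on\w+\s*=` event-handler pattern the edge-case test mentions (the other regexes and the exact message are assumptions):

```python
# Hypothetical deny-list field validator in the spirit of the tests above.
import re

from pydantic import BaseModel, field_validator

_DANGEROUS_PATTERNS = [
    re.compile(r"<\s*script", re.IGNORECASE),
    re.compile(r"<\s*(iframe|object|embed|link|svg)", re.IGNORECASE),
    re.compile(r"javascript\s*:", re.IGNORECASE),
    re.compile(r"<.*?on\w+\s*=", re.IGNORECASE),  # inline event handlers
]


class AppNamePayload(BaseModel):
    name: str

    @field_validator("name")
    @classmethod
    def reject_dangerous_markup(cls, value: str) -> str:
        for pattern in _DANGEROUS_PATTERNS:
            if pattern.search(value):
                raise ValueError("Name contains invalid characters or patterns")
        return value
```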
@ -171,7 +171,7 @@ class TestOAuthCallback:
    ):
        mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
        mock_get_providers.return_value = {"github": oauth_setup["provider"]}
        mock_generate_account.return_value = oauth_setup["account"]
        mock_generate_account.return_value = (oauth_setup["account"], True)
        mock_account_service.login.return_value = oauth_setup["token_pair"]

        with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
@ -179,7 +179,7 @@ class TestOAuthCallback:

        oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
        oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
        mock_redirect.assert_called_once_with("http://localhost:3000")
        mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=true")

    @pytest.mark.parametrize(
        ("exception", "expected_error"),
@ -223,7 +223,7 @@ class TestOAuthCallback:
            # This documents actual behavior. See test_defensive_check_for_closed_account_status for details
            (
                AccountStatus.CLOSED.value,
                "http://localhost:3000",
                "http://localhost:3000?oauth_new_user=false",
            ),
        ],
    )
@ -260,7 +260,7 @@ class TestOAuthCallback:
        account = MagicMock()
        account.status = account_status
        account.id = "123"
        mock_generate_account.return_value = account
        mock_generate_account.return_value = (account, False)

        # Mock login for CLOSED status
        mock_token_pair = MagicMock()
@ -296,7 +296,7 @@ class TestOAuthCallback:

        mock_account = MagicMock()
        mock_account.status = AccountStatus.PENDING
        mock_generate_account.return_value = mock_account
        mock_generate_account.return_value = (mock_account, False)

        mock_token_pair = MagicMock()
        mock_token_pair.access_token = "jwt_access_token"
@ -360,7 +360,7 @@ class TestOAuthCallback:
        closed_account.status = AccountStatus.CLOSED
        closed_account.id = "123"
        closed_account.name = "Closed Account"
        mock_generate_account.return_value = closed_account
        mock_generate_account.return_value = (closed_account, False)

        # Mock successful login (current behavior)
        mock_token_pair = MagicMock()
@ -374,7 +374,7 @@ class TestOAuthCallback:
        resource.get("github")

        # Verify current behavior: login succeeds (this is NOT ideal)
        mock_redirect.assert_called_once_with("http://localhost:3000")
        mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=false")
        mock_account_service.login.assert_called_once()

        # Document expected behavior in comments:
@ -458,8 +458,9 @@ class TestAccountGeneration:
            with pytest.raises(AccountRegisterError):
                _generate_account("github", user_info)
        else:
            result = _generate_account("github", user_info)
            result, oauth_new_user = _generate_account("github", user_info)
            assert result == mock_account
            assert oauth_new_user == should_create

            if should_create:
                mock_register_service.register.assert_called_once_with(
@ -490,9 +491,10 @@ class TestAccountGeneration:
        mock_tenant_service.create_tenant.return_value = mock_new_tenant

        with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
            result = _generate_account("github", user_info)
            result, oauth_new_user = _generate_account("github", user_info)

        assert result == mock_account
        assert oauth_new_user is False
        mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
        mock_tenant_service.create_tenant_member.assert_called_once_with(
            mock_new_tenant, mock_account, role="owner"
@ -41,13 +41,10 @@ def client():
@patch(
    "controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
)
@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None)
@patch("controllers.console.workspace.tool_providers.Session")
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
def test_create_mcp_provider_populates_tools(
    mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client
):
def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client):
    # Arrange: reconnect returns tools immediately
    mock_reconnect.return_value = ReconnectResult(
        authed=True,
@ -1,126 +0,0 @@
import json
from unittest.mock import patch

import pytest
from redis.exceptions import RedisError

from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral


@pytest.fixture
def mock_redis_client():
    """Fixture: Mock Redis client"""
    with patch("core.helper.tool_provider_cache.redis_client") as mock:
        yield mock


class TestToolProviderListCache:
    """Test class for ToolProviderListCache"""

    def test_generate_cache_key(self):
        """Test cache key generation logic"""
        # Scenario 1: Specify typ (valid literal value)
        tenant_id = "tenant_123"
        typ: ToolProviderTypeApiLiteral = "builtin"
        expected_key = f"tool_providers:tenant_id:{tenant_id}:type:{typ}"
        assert ToolProviderListCache._generate_cache_key(tenant_id, typ) == expected_key

        # Scenario 2: typ is None (defaults to "all")
        expected_key_all = f"tool_providers:tenant_id:{tenant_id}:type:all"
        assert ToolProviderListCache._generate_cache_key(tenant_id) == expected_key_all

    def test_get_cached_providers_hit(self, mock_redis_client):
        """Test get cached providers - cache hit and successful decoding"""
        tenant_id = "tenant_123"
        typ: ToolProviderTypeApiLiteral = "api"
        mock_providers = [{"id": "tool", "name": "test_provider"}]
        mock_redis_client.get.return_value = json.dumps(mock_providers).encode("utf-8")

        result = ToolProviderListCache.get_cached_providers(tenant_id, typ)

        mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant_id, typ))
        assert result == mock_providers

    def test_get_cached_providers_decode_error(self, mock_redis_client):
        """Test get cached providers - cache hit but decoding failed"""
        tenant_id = "tenant_123"
        mock_redis_client.get.return_value = b"invalid_json_data"

        result = ToolProviderListCache.get_cached_providers(tenant_id)

        assert result is None
        mock_redis_client.get.assert_called_once()

    def test_get_cached_providers_miss(self, mock_redis_client):
        """Test get cached providers - cache miss"""
        tenant_id = "tenant_123"
        mock_redis_client.get.return_value = None

        result = ToolProviderListCache.get_cached_providers(tenant_id)

        assert result is None
        mock_redis_client.get.assert_called_once()

    def test_set_cached_providers(self, mock_redis_client):
        """Test set cached providers"""
        tenant_id = "tenant_123"
        typ: ToolProviderTypeApiLiteral = "builtin"
        mock_providers = [{"id": "tool", "name": "test_provider"}]
        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)

        ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers)

        mock_redis_client.setex.assert_called_once_with(
            cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers)
        )

    def test_invalidate_cache_specific_type(self, mock_redis_client):
        """Test invalidate cache - specific type"""
        tenant_id = "tenant_123"
        typ: ToolProviderTypeApiLiteral = "workflow"
        cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)

        ToolProviderListCache.invalidate_cache(tenant_id, typ)

        mock_redis_client.delete.assert_called_once_with(cache_key)

    def test_invalidate_cache_all_types(self, mock_redis_client):
        """Test invalidate cache - clear all tenant cache"""
        tenant_id = "tenant_123"
        mock_keys = [
            b"tool_providers:tenant_id:tenant_123:type:all",
            b"tool_providers:tenant_id:tenant_123:type:builtin",
        ]
        mock_redis_client.scan_iter.return_value = mock_keys

        ToolProviderListCache.invalidate_cache(tenant_id)

    def test_invalidate_cache_no_keys(self, mock_redis_client):
        """Test invalidate cache - no cache keys for tenant"""
        tenant_id = "tenant_123"
        mock_redis_client.scan_iter.return_value = []

        ToolProviderListCache.invalidate_cache(tenant_id)

        mock_redis_client.delete.assert_not_called()

    def test_redis_fallback_default_return(self, mock_redis_client):
        """Test redis_fallback decorator - default return value (Redis error)"""
        mock_redis_client.get.side_effect = RedisError("Redis connection error")

        result = ToolProviderListCache.get_cached_providers("tenant_123")

        assert result is None
        mock_redis_client.get.assert_called_once()

    def test_redis_fallback_no_default(self, mock_redis_client):
        """Test redis_fallback decorator - no default return value (Redis error)"""
        mock_redis_client.setex.side_effect = RedisError("Redis connection error")

        try:
            ToolProviderListCache.set_cached_providers("tenant_123", "mcp", [])
        except RedisError:
            pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)")

        mock_redis_client.setex.assert_called_once()
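The deleted tests above exercised a `redis_fallback` decorator: read paths return a default instead of raising, and write paths swallow the error. Its definition is not part of this diff; a minimal sketch of the behavior those two tests described:

```python
# Sketch of a redis_fallback decorator; the real helper may differ.
import functools
import logging

from redis.exceptions import RedisError

logger = logging.getLogger(__name__)


def redis_fallback(default_return=None):
    """Return `default_return` instead of raising when Redis errors out."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except RedisError:
                logger.warning("Redis unavailable in %s, falling back", func.__name__)
                return default_return
        return wrapper
    return decorator
```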
0    api/tests/unit_tests/core/rag/cleaner/__init__.py    Normal file
213  api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py    Normal file
@ -0,0 +1,213 @@
from core.rag.cleaner.clean_processor import CleanProcessor


class TestCleanProcessor:
    """Test cases for CleanProcessor.clean method."""

    def test_clean_default_removal_of_invalid_symbols(self):
        """Test default cleaning removes invalid symbols."""
        # Test <| replacement
        assert CleanProcessor.clean("text<|with<|invalid", None) == "text<with<invalid"

        # Test |> replacement
        assert CleanProcessor.clean("text|>with|>invalid", None) == "text>with>invalid"

        # Test removal of control characters
        text_with_control = "normal\x00text\x1fwith\x07control\x7fchars"
        expected = "normaltextwithcontrolchars"
        assert CleanProcessor.clean(text_with_control, None) == expected

        # Test U+FFFE removal
        text_with_ufffe = "normal\ufffepadding"
        expected = "normalpadding"
        assert CleanProcessor.clean(text_with_ufffe, None) == expected

    def test_clean_with_none_process_rule(self):
        """Test cleaning with None process_rule - only default cleaning applied."""
        text = "Hello<|World\x00"
        expected = "Hello<World"
        assert CleanProcessor.clean(text, None) == expected

    def test_clean_with_empty_process_rule(self):
        """Test cleaning with empty process_rule dict - only default cleaning applied."""
        text = "Hello<|World\x00"
        expected = "Hello<World"
        assert CleanProcessor.clean(text, {}) == expected

    def test_clean_with_empty_rules(self):
        """Test cleaning with empty rules - only default cleaning applied."""
        text = "Hello<|World\x00"
        expected = "Hello<World"
        assert CleanProcessor.clean(text, {"rules": {}}) == expected

    def test_clean_remove_extra_spaces_enabled(self):
        """Test remove_extra_spaces rule when enabled."""
        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}]}}

        # Test multiple newlines reduced to two
        text = "Line1\n\n\n\n\nLine2"
        expected = "Line1\n\nLine2"
        assert CleanProcessor.clean(text, process_rule) == expected

        # Test various whitespace characters reduced to single space
        text = "word1\u2000\u2001\t\t \u3000word2"
        expected = "word1 word2"
        assert CleanProcessor.clean(text, process_rule) == expected

        # Test combination of newlines and spaces
        text = "Line1\n\n\n\n \t Line2"
        expected = "Line1\n\n Line2"
        assert CleanProcessor.clean(text, process_rule) == expected

    def test_clean_remove_extra_spaces_disabled(self):
        """Test remove_extra_spaces rule when disabled."""
        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": False}]}}

        text = "Line1\n\n\n\n\nLine2 with spaces"
        # Should only apply default cleaning (no invalid symbols here)
        assert CleanProcessor.clean(text, process_rule) == text

    def test_clean_remove_urls_emails_enabled(self):
        """Test remove_urls_emails rule when enabled."""
        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}

        # Test email removal
        text = "Contact us at test@example.com for more info"
        expected = "Contact us at for more info"
        assert CleanProcessor.clean(text, process_rule) == expected

        # Test URL removal
        text = "Visit https://example.com or http://test.org"
        expected = "Visit or "
        assert CleanProcessor.clean(text, process_rule) == expected

        # Test both email and URL
        text = "Email me@test.com and visit https://site.com"
        expected = "Email and visit "
        assert CleanProcessor.clean(text, process_rule) == expected

    def test_clean_preserve_markdown_links_and_images(self):
        """Test that markdown links and images are preserved when removing URLs."""
        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}

        # Test markdown link preservation
        text = "Check [Google](https://google.com) for info"
        expected = "Check [Google](https://google.com) for info"
        assert CleanProcessor.clean(text, process_rule) == expected

        # Test markdown image preservation
        text = "Image: "
        expected = "Image: "
        assert CleanProcessor.clean(text, process_rule) == expected

        # Test both link and image preservation
        text = "[Link](https://link.com) and "
        expected = "[Link](https://link.com) and "
        assert CleanProcessor.clean(text, process_rule) == expected

        # Test that non-markdown URLs are still removed
        text = "Check [Link](https://keep.com) but remove https://remove.com"
        expected = "Check [Link](https://keep.com) but remove "
        assert CleanProcessor.clean(text, process_rule) == expected

        # Test email removal alongside markdown preservation
        text = "Email: test@test.com, link: [Click](https://site.com)"
        expected = "Email: , link: [Click](https://site.com)"
        assert CleanProcessor.clean(text, process_rule) == expected

    def test_clean_remove_urls_emails_disabled(self):
        """Test remove_urls_emails rule when disabled."""
        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": False}]}}

        text = "Email test@example.com visit https://example.com"
        # Should only apply default cleaning
        assert CleanProcessor.clean(text, process_rule) == text

    def test_clean_both_rules_enabled(self):
        """Test both pre-processing rules enabled together."""
        process_rule = {
            "rules": {
                "pre_processing_rules": [
                    {"id": "remove_extra_spaces", "enabled": True},
                    {"id": "remove_urls_emails", "enabled": True},
                ]
            }
        }

        text = "Hello\n\n\n\n World test@example.com \n\n\nhttps://example.com"
        expected = "Hello\n\n World \n\n"
        assert CleanProcessor.clean(text, process_rule) == expected

    def test_clean_with_markdown_link_and_extra_spaces(self):
        """Test markdown link preservation with extra spaces removal."""
        process_rule = {
            "rules": {
                "pre_processing_rules": [
                    {"id": "remove_extra_spaces", "enabled": True},
                    {"id": "remove_urls_emails", "enabled": True},
                ]
            }
        }

        text = "[Link](https://example.com)\n\n\n\n Text https://remove.com"
        expected = "[Link](https://example.com)\n\n Text "
        assert CleanProcessor.clean(text, process_rule) == expected

    def test_clean_unknown_rule_id_ignored(self):
        """Test that unknown rule IDs are silently ignored."""
        process_rule = {"rules": {"pre_processing_rules": [{"id": "unknown_rule", "enabled": True}]}}

        text = "Hello<|World\x00"
        expected = "Hello<World"
        # Only default cleaning should be applied
        assert CleanProcessor.clean(text, process_rule) == expected

    def test_clean_empty_text(self):
        """Test cleaning empty text."""
        assert CleanProcessor.clean("", None) == ""
        assert CleanProcessor.clean("", {}) == ""
        assert CleanProcessor.clean("", {"rules": {}}) == ""

    def test_clean_text_with_only_invalid_symbols(self):
        """Test text containing only invalid symbols."""
        text = "<|<|\x00\x01\x02\ufffe|>|>"
        # <| becomes <, |> becomes >, control chars and U+FFFE are removed
        assert CleanProcessor.clean(text, None) == "<<>>"

    def test_clean_multiple_markdown_links_preserved(self):
        """Test multiple markdown links are all preserved."""
        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}

        text = "[One](https://one.com) [Two](http://two.org) [Three](https://three.net)"
        expected = "[One](https://one.com) [Two](http://two.org) [Three](https://three.net)"
        assert CleanProcessor.clean(text, process_rule) == expected

    def test_clean_markdown_link_text_as_url(self):
        """Test markdown link where the link text itself is a URL."""
        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}

        # Link text that looks like URL should be preserved
        text = "[https://text-url.com](https://actual-url.com)"
        expected = "[https://text-url.com](https://actual-url.com)"
        assert CleanProcessor.clean(text, process_rule) == expected

        # Text URL without markdown should be removed
        text = "https://text-url.com https://actual-url.com"
        expected = " "
        assert CleanProcessor.clean(text, process_rule) == expected

    def test_clean_complex_markdown_link_content(self):
        """Test markdown links with complex content - known limitation with brackets in link text."""
        process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}

        # Note: The regex pattern [^\]]* cannot handle ] within link text
        # This is a known limitation - the pattern stops at the first ]
        text = "[Text with [brackets] and (parens)](https://example.com)"
        # Actual behavior: only matches up to first ], URL gets removed
        expected = "[Text with [brackets] and (parens)]("
        assert CleanProcessor.clean(text, process_rule) == expected

        # Test that properly formatted markdown links work
        text = "[Text with (parens) and symbols](https://example.com)"
        expected = "[Text with (parens) and symbols](https://example.com)"
        assert CleanProcessor.clean(text, process_rule) == expected
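The preservation tests above imply a protect-then-strip approach: markdown links and images are masked before bare URLs and emails are removed, then restored. CleanProcessor's actual regexes are not shown here; the sketch below uses a `[^\]]*` link-text pattern on purpose, since that reproduces the bracket limitation the last test documents:

```python
# Protect-then-strip sketch of markdown-aware URL/email removal (assumed mechanism).
import re

MD_LINK = re.compile(r"!?\[[^\]]*\]\([^)]*\)")  # links and images; stops at first ]
EMAIL = re.compile(r"[\w.-]+@[\w.-]+\.\w+")
BARE_URL = re.compile(r"https?://\S+")


def remove_urls_and_emails(text: str) -> str:
    placeholders: list[str] = []

    def mask(match: re.Match) -> str:
        placeholders.append(match.group(0))
        return f"\x00MD{len(placeholders) - 1}\x00"

    masked = MD_LINK.sub(mask, text)  # protect markdown spans first
    masked = EMAIL.sub("", masked)
    masked = BARE_URL.sub("", masked)
    for i, original in enumerate(placeholders):  # restore protected spans
        masked = masked.replace(f"\x00MD{i}\x00", original)
    return masked


assert remove_urls_and_emails("See [Docs](https://d.io) or https://x.io") == "See [Docs](https://d.io) or "
```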
186  api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py    Normal file
@ -0,0 +1,186 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

import core.rag.extractor.pdf_extractor as pe


@pytest.fixture
def mock_dependencies(monkeypatch):
    # Mock storage
    saves = []

    def save(key, data):
        saves.append((key, data))

    monkeypatch.setattr(pe, "storage", SimpleNamespace(save=save))

    # Mock db
    class DummySession:
        def __init__(self):
            self.added = []
            self.committed = False

        def add(self, obj):
            self.added.append(obj)

        def add_all(self, objs):
            self.added.extend(objs)

        def commit(self):
            self.committed = True

    db_stub = SimpleNamespace(session=DummySession())
    monkeypatch.setattr(pe, "db", db_stub)

    # Mock UploadFile
    class FakeUploadFile:
        DEFAULT_ID = "test_file_id"

        def __init__(self, **kwargs):
            # Assign id from DEFAULT_ID, allow override via kwargs if needed
            self.id = self.DEFAULT_ID
            for k, v in kwargs.items():
                setattr(self, k, v)

    monkeypatch.setattr(pe, "UploadFile", FakeUploadFile)

    # Mock config
    monkeypatch.setattr(pe.dify_config, "FILES_URL", "http://files.local")
    monkeypatch.setattr(pe.dify_config, "INTERNAL_FILES_URL", None)
    monkeypatch.setattr(pe.dify_config, "STORAGE_TYPE", "local")

    return SimpleNamespace(saves=saves, db=db_stub, UploadFile=FakeUploadFile)


@pytest.mark.parametrize(
    ("image_bytes", "expected_mime", "expected_ext", "file_id"),
    [
        (b"\xff\xd8\xff some jpeg", "image/jpeg", "jpg", "test_file_id_jpeg"),
        (b"\x89PNG\r\n\x1a\n some png", "image/png", "png", "test_file_id_png"),
    ],
)
def test_extract_images_formats(mock_dependencies, monkeypatch, image_bytes, expected_mime, expected_ext, file_id):
    saves = mock_dependencies.saves
    db_stub = mock_dependencies.db

    # Customize FakeUploadFile id for this test case.
    # Using monkeypatch ensures the class attribute is reset between parameter sets.
    monkeypatch.setattr(mock_dependencies.UploadFile, "DEFAULT_ID", file_id)

    # Mock page and image objects
    mock_page = MagicMock()
    mock_image_obj = MagicMock()

    def mock_extract(buf, fb_format=None):
        buf.write(image_bytes)

    mock_image_obj.extract.side_effect = mock_extract

    mock_page.get_objects.return_value = [mock_image_obj]

    extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")

    # We need to handle the import inside _extract_images
    with patch("pypdfium2.raw") as mock_raw:
        mock_raw.FPDF_PAGEOBJ_IMAGE = 1
        result = extractor._extract_images(mock_page)

    assert f"" in result
    assert len(saves) == 1
    assert saves[0][1] == image_bytes
    assert len(db_stub.session.added) == 1
    assert db_stub.session.added[0].tenant_id == "t1"
    assert db_stub.session.added[0].size == len(image_bytes)
    assert db_stub.session.added[0].mime_type == expected_mime
    assert db_stub.session.added[0].extension == expected_ext
    assert db_stub.session.committed is True


@pytest.mark.parametrize(
    ("get_objects_side_effect", "get_objects_return_value"),
    [
        (None, []),  # Empty list
        (None, None),  # None returned
        (Exception("Failed to get objects"), None),  # Exception raised
    ],
)
def test_extract_images_get_objects_scenarios(mock_dependencies, get_objects_side_effect, get_objects_return_value):
    mock_page = MagicMock()
    if get_objects_side_effect:
        mock_page.get_objects.side_effect = get_objects_side_effect
    else:
        mock_page.get_objects.return_value = get_objects_return_value

    extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")

    with patch("pypdfium2.raw") as mock_raw:
        mock_raw.FPDF_PAGEOBJ_IMAGE = 1
        result = extractor._extract_images(mock_page)

    assert result == ""


def test_extract_calls_extract_images(mock_dependencies, monkeypatch):
    # Mock pypdfium2
    mock_pdf_doc = MagicMock()
    mock_page = MagicMock()
    mock_pdf_doc.__iter__.return_value = [mock_page]

    # Mock text extraction
    mock_text_page = MagicMock()
    mock_text_page.get_text_range.return_value = "Page text content"
    mock_page.get_textpage.return_value = mock_text_page

    with patch("pypdfium2.PdfDocument", return_value=mock_pdf_doc):
        # Mock Blob
        mock_blob = MagicMock()
        mock_blob.source = "test.pdf"
        with patch("core.rag.extractor.pdf_extractor.Blob.from_path", return_value=mock_blob):
            extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")

            # Mock _extract_images to return a known string
            monkeypatch.setattr(extractor, "_extract_images", lambda p: "")

            documents = list(extractor.extract())

    assert len(documents) == 1
    assert "Page text content" in documents[0].page_content
    assert "" in documents[0].page_content
    assert documents[0].metadata["page"] == 0


def test_extract_images_failures(mock_dependencies):
    saves = mock_dependencies.saves
    db_stub = mock_dependencies.db

    # Mock page and image objects
    mock_page = MagicMock()
    mock_image_obj_fail = MagicMock()
    mock_image_obj_ok = MagicMock()

    # First image raises exception
    mock_image_obj_fail.extract.side_effect = Exception("Extraction failure")

    # Second image is OK (JPEG)
    jpeg_bytes = b"\xff\xd8\xff some image data"

    def mock_extract(buf, fb_format=None):
        buf.write(jpeg_bytes)

    mock_image_obj_ok.extract.side_effect = mock_extract

    mock_page.get_objects.return_value = [mock_image_obj_fail, mock_image_obj_ok]

    extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")

    with patch("pypdfium2.raw") as mock_raw:
        mock_raw.FPDF_PAGEOBJ_IMAGE = 1
        result = extractor._extract_images(mock_page)

    # Should have one success
    assert "" in result
    assert len(saves) == 1
    assert saves[0][1] == jpeg_bytes
    assert db_stub.session.committed is True
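The parametrized cases above hinge on magic-byte sniffing: the extractor has to pick a MIME type and extension from the first bytes of each extracted image. A condensed sketch covering just the two signatures the tests use:

```python
# Magic-byte MIME sniffing, reduced to the JPEG and PNG cases tested above.
def sniff_image_type(data: bytes) -> tuple[str, str]:
    if data.startswith(b"\xff\xd8\xff"):  # JPEG SOI marker
        return "image/jpeg", "jpg"
    if data.startswith(b"\x89PNG\r\n\x1a\n"):  # 8-byte PNG signature
        return "image/png", "png"
    return "application/octet-stream", "bin"  # fallback for anything else


assert sniff_image_type(b"\xff\xd8\xff rest") == ("image/jpeg", "jpg")
assert sniff_image_type(b"\x89PNG\r\n\x1a\n rest") == ("image/png", "png")
```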
@ -421,7 +421,18 @@ class TestRetrievalService:
            # In real code, this waits for all futures to complete
            # In tests, futures complete immediately, so wait is a no-op
            with patch("core.rag.datasource.retrieval_service.concurrent.futures.wait"):
                yield mock_executor
            # Mock concurrent.futures.as_completed for early error propagation
            # In real code, this yields futures as they complete
            # In tests, we yield all futures immediately since they're already done
            def mock_as_completed(futures_list, timeout=None):
                """Mock as_completed that yields futures immediately."""
                yield from futures_list

            with patch(
                "core.rag.datasource.retrieval_service.concurrent.futures.as_completed",
                side_effect=mock_as_completed,
            ):
                yield mock_executor

    # ==================== Vector Search Tests ====================

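The hunk above swaps a mocked `futures.wait` for a mocked `as_completed`, and the distinction is what enables early error propagation: `wait()` blocks until every future finishes before any result is inspected, while iterating `as_completed()` re-raises the first worker exception as soon as its future resolves. A standalone illustration:

```python
# Why as_completed propagates errors early: result() re-raises immediately.
import concurrent.futures


def boom() -> None:
    raise RuntimeError("search backend failed")


with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(boom) for _ in range(3)]
    try:
        for future in concurrent.futures.as_completed(futures):
            future.result()  # surfaces the worker's exception right away
    except RuntimeError as exc:
        print(f"propagated early: {exc}")
```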
@ -0,0 +1 @@
"""Tests for graph traversal components."""
@ -0,0 +1,307 @@
"""Unit tests for skip propagator."""

from unittest.mock import MagicMock, create_autospec

from core.workflow.graph import Edge, Graph
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.graph_traversal.skip_propagator import SkipPropagator


class TestSkipPropagator:
    """Test suite for SkipPropagator."""

    def test_propagate_skip_from_edge_with_unknown_edges_stops_processing(self) -> None:
        """When there are unknown incoming edges, propagation should stop."""
        # Arrange
        mock_graph = create_autospec(Graph)
        mock_state_manager = create_autospec(GraphStateManager)

        # Create a mock edge
        mock_edge = MagicMock(spec=Edge)
        mock_edge.id = "edge_1"
        mock_edge.head = "node_2"

        # Setup graph edges dict
        mock_graph.edges = {"edge_1": mock_edge}

        # Setup incoming edges
        incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge)]
        mock_graph.get_incoming_edges.return_value = incoming_edges

        # Setup state manager to return has_unknown=True
        mock_state_manager.analyze_edge_states.return_value = {
            "has_unknown": True,
            "has_taken": False,
            "all_skipped": False,
        }

        propagator = SkipPropagator(mock_graph, mock_state_manager)

        # Act
        propagator.propagate_skip_from_edge("edge_1")

        # Assert
        mock_graph.get_incoming_edges.assert_called_once_with("node_2")
        mock_state_manager.analyze_edge_states.assert_called_once_with(incoming_edges)
        # Should not call any other state manager methods
        mock_state_manager.enqueue_node.assert_not_called()
        mock_state_manager.start_execution.assert_not_called()
        mock_state_manager.mark_node_skipped.assert_not_called()

    def test_propagate_skip_from_edge_with_taken_edge_enqueues_node(self) -> None:
        """When there is at least one taken edge, node should be enqueued."""
        # Arrange
        mock_graph = create_autospec(Graph)
        mock_state_manager = create_autospec(GraphStateManager)

        # Create a mock edge
        mock_edge = MagicMock(spec=Edge)
        mock_edge.id = "edge_1"
        mock_edge.head = "node_2"

        mock_graph.edges = {"edge_1": mock_edge}
        incoming_edges = [MagicMock(spec=Edge)]
        mock_graph.get_incoming_edges.return_value = incoming_edges

        # Setup state manager to return has_taken=True
        mock_state_manager.analyze_edge_states.return_value = {
            "has_unknown": False,
            "has_taken": True,
            "all_skipped": False,
        }

        propagator = SkipPropagator(mock_graph, mock_state_manager)

        # Act
        propagator.propagate_skip_from_edge("edge_1")

        # Assert
        mock_state_manager.enqueue_node.assert_called_once_with("node_2")
        mock_state_manager.start_execution.assert_called_once_with("node_2")
        mock_state_manager.mark_node_skipped.assert_not_called()

    def test_propagate_skip_from_edge_with_all_skipped_propagates_to_node(self) -> None:
        """When all incoming edges are skipped, should propagate skip to node."""
        # Arrange
        mock_graph = create_autospec(Graph)
        mock_state_manager = create_autospec(GraphStateManager)

        # Create a mock edge
        mock_edge = MagicMock(spec=Edge)
        mock_edge.id = "edge_1"
        mock_edge.head = "node_2"

        mock_graph.edges = {"edge_1": mock_edge}
        incoming_edges = [MagicMock(spec=Edge)]
        mock_graph.get_incoming_edges.return_value = incoming_edges

        # Setup state manager to return all_skipped=True
        mock_state_manager.analyze_edge_states.return_value = {
            "has_unknown": False,
            "has_taken": False,
            "all_skipped": True,
        }

        propagator = SkipPropagator(mock_graph, mock_state_manager)

        # Act
        propagator.propagate_skip_from_edge("edge_1")

        # Assert
        mock_state_manager.mark_node_skipped.assert_called_once_with("node_2")
        mock_state_manager.enqueue_node.assert_not_called()
        mock_state_manager.start_execution.assert_not_called()

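Taken together, the three tests above fix the branch rule applied to a downstream node's incoming edges. A condensed restatement of that decision table (key names follow the mocked `analyze_edge_states` result; the real SkipPropagator wiring is richer):

```python
# Three-way skip-propagation decision, distilled from the tests above.
def decide(edge_states: dict[str, bool]) -> str:
    if edge_states["has_unknown"]:
        return "wait"  # some edges undecided: do nothing yet
    if edge_states["has_taken"]:
        return "enqueue"  # at least one live path: the node must run
    if edge_states["all_skipped"]:
        return "skip"  # no live path remains: skip the node and recurse
    return "wait"


assert decide({"has_unknown": True, "has_taken": False, "all_skipped": False}) == "wait"
assert decide({"has_unknown": False, "has_taken": True, "all_skipped": False}) == "enqueue"
assert decide({"has_unknown": False, "has_taken": False, "all_skipped": True}) == "skip"
```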
    def test_propagate_skip_to_node_marks_node_and_outgoing_edges_skipped(self) -> None:
        """_propagate_skip_to_node should mark node and all outgoing edges as skipped."""
        # Arrange
        mock_graph = create_autospec(Graph)
        mock_state_manager = create_autospec(GraphStateManager)

        # Create outgoing edges
        edge1 = MagicMock(spec=Edge)
        edge1.id = "edge_2"
        edge1.head = "node_downstream_1"  # Set head for propagate_skip_from_edge

        edge2 = MagicMock(spec=Edge)
        edge2.id = "edge_3"
        edge2.head = "node_downstream_2"

        # Setup graph edges dict for propagate_skip_from_edge
        mock_graph.edges = {"edge_2": edge1, "edge_3": edge2}
        mock_graph.get_outgoing_edges.return_value = [edge1, edge2]

        # Setup get_incoming_edges to return empty list to stop recursion
        mock_graph.get_incoming_edges.return_value = []

        propagator = SkipPropagator(mock_graph, mock_state_manager)

        # Act - call the private method directly
        propagator._propagate_skip_to_node("node_1")

        # Assert
        mock_state_manager.mark_node_skipped.assert_called_once_with("node_1")
        mock_state_manager.mark_edge_skipped.assert_any_call("edge_2")
        mock_state_manager.mark_edge_skipped.assert_any_call("edge_3")
        assert mock_state_manager.mark_edge_skipped.call_count == 2
        # propagate_skip_from_edge then recurses from each outgoing edge; with
        # get_incoming_edges returning [], the recursion terminates right here.

    def test_skip_branch_paths_marks_unselected_edges_and_propagates(self) -> None:
        """skip_branch_paths should mark all unselected edges as skipped and propagate."""
        # Arrange
        mock_graph = create_autospec(Graph)
        mock_state_manager = create_autospec(GraphStateManager)

        # Create unselected edges
        edge1 = MagicMock(spec=Edge)
        edge1.id = "edge_1"
        edge1.head = "node_downstream_1"

        edge2 = MagicMock(spec=Edge)
        edge2.id = "edge_2"
        edge2.head = "node_downstream_2"

        unselected_edges = [edge1, edge2]

        # Setup graph edges dict
        mock_graph.edges = {"edge_1": edge1, "edge_2": edge2}
        # Setup get_incoming_edges to return empty list to stop recursion
        mock_graph.get_incoming_edges.return_value = []

        propagator = SkipPropagator(mock_graph, mock_state_manager)

        # Act
        propagator.skip_branch_paths(unselected_edges)

        # Assert
        mock_state_manager.mark_edge_skipped.assert_any_call("edge_1")
        mock_state_manager.mark_edge_skipped.assert_any_call("edge_2")
        assert mock_state_manager.mark_edge_skipped.call_count == 2
        # propagate_skip_from_edge runs for each unselected edge; the recursion
        # stops immediately because get_incoming_edges returns an empty list.

def test_propagate_skip_from_edge_recursively_propagates_through_graph(self) -> None:
|
||||
"""Skip propagation should recursively propagate through the graph."""
|
||||
# Arrange
|
||||
mock_graph = create_autospec(Graph)
|
||||
mock_state_manager = create_autospec(GraphStateManager)
|
||||
|
||||
# Create edge chain: edge_1 -> node_2 -> edge_3 -> node_4
|
||||
edge1 = MagicMock(spec=Edge)
|
||||
edge1.id = "edge_1"
|
||||
edge1.head = "node_2"
|
||||
|
||||
edge3 = MagicMock(spec=Edge)
|
||||
edge3.id = "edge_3"
|
||||
edge3.head = "node_4"
|
||||
|
||||
mock_graph.edges = {"edge_1": edge1, "edge_3": edge3}
|
||||
|
||||
# Setup get_incoming_edges to return different values based on node
|
||||
def get_incoming_edges_side_effect(node_id):
|
||||
if node_id == "node_2":
|
||||
return [edge1]
|
||||
elif node_id == "node_4":
|
||||
return [edge3]
|
||||
return []
|
||||
|
||||
mock_graph.get_incoming_edges.side_effect = get_incoming_edges_side_effect
|
||||
|
||||
# Setup get_outgoing_edges to return different values based on node
|
||||
def get_outgoing_edges_side_effect(node_id):
|
||||
if node_id == "node_2":
|
||||
return [edge3]
|
||||
elif node_id == "node_4":
|
||||
return [] # No outgoing edges, stops recursion
|
||||
return []
|
||||
|
||||
mock_graph.get_outgoing_edges.side_effect = get_outgoing_edges_side_effect
|
||||
|
||||
# Setup state manager to return all_skipped for both nodes
|
||||
mock_state_manager.analyze_edge_states.return_value = {
|
||||
"has_unknown": False,
|
||||
"has_taken": False,
|
||||
"all_skipped": True,
|
||||
}
|
||||
|
||||
propagator = SkipPropagator(mock_graph, mock_state_manager)
|
||||
|
||||
# Act
|
||||
propagator.propagate_skip_from_edge("edge_1")
|
||||
|
||||
# Assert
|
||||
# Should mark node_2 as skipped
|
||||
mock_state_manager.mark_node_skipped.assert_any_call("node_2")
|
||||
# Should mark edge_3 as skipped
|
||||
mock_state_manager.mark_edge_skipped.assert_any_call("edge_3")
|
||||
# Should propagate to node_4
|
||||
mock_state_manager.mark_node_skipped.assert_any_call("node_4")
|
||||
assert mock_state_manager.mark_node_skipped.call_count == 2
|
||||
|
||||
def test_propagate_skip_from_edge_with_mixed_edge_states_handles_correctly(self) -> None:
|
||||
"""Test with mixed edge states (some unknown, some taken, some skipped)."""
|
||||
# Arrange
|
||||
mock_graph = create_autospec(Graph)
|
||||
mock_state_manager = create_autospec(GraphStateManager)
|
||||
|
||||
mock_edge = MagicMock(spec=Edge)
|
||||
mock_edge.id = "edge_1"
|
||||
mock_edge.head = "node_2"
|
||||
|
||||
mock_graph.edges = {"edge_1": mock_edge}
|
||||
incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge), MagicMock(spec=Edge)]
|
||||
mock_graph.get_incoming_edges.return_value = incoming_edges
|
||||
|
||||
# Test 1: has_unknown=True, has_taken=False, all_skipped=False
|
||||
mock_state_manager.analyze_edge_states.return_value = {
|
||||
"has_unknown": True,
|
||||
"has_taken": False,
|
||||
"all_skipped": False,
|
||||
}
|
||||
|
||||
propagator = SkipPropagator(mock_graph, mock_state_manager)
|
||||
|
||||
# Act
|
||||
propagator.propagate_skip_from_edge("edge_1")
|
||||
|
||||
# Assert - should stop processing
|
||||
mock_state_manager.enqueue_node.assert_not_called()
|
||||
mock_state_manager.mark_node_skipped.assert_not_called()
|
||||
|
||||
# Reset mocks for next test
|
||||
mock_state_manager.reset_mock()
|
||||
mock_graph.reset_mock()
|
||||
|
||||
# Test 2: has_unknown=False, has_taken=True, all_skipped=False
|
||||
mock_state_manager.analyze_edge_states.return_value = {
|
||||
"has_unknown": False,
|
||||
"has_taken": True,
|
||||
"all_skipped": False,
|
||||
}
|
||||
|
||||
# Act
|
||||
propagator.propagate_skip_from_edge("edge_1")
|
||||
|
||||
# Assert - should enqueue node
|
||||
mock_state_manager.enqueue_node.assert_called_once_with("node_2")
|
||||
mock_state_manager.start_execution.assert_called_once_with("node_2")
|
||||
|
||||
# Reset mocks for next test
|
||||
mock_state_manager.reset_mock()
|
||||
mock_graph.reset_mock()
|
||||
|
||||
# Test 3: has_unknown=False, has_taken=False, all_skipped=True
|
||||
mock_state_manager.analyze_edge_states.return_value = {
|
||||
"has_unknown": False,
|
||||
"has_taken": False,
|
||||
"all_skipped": True,
|
||||
}
|
||||
|
||||
# Act
|
||||
propagator.propagate_skip_from_edge("edge_1")
|
||||
|
||||
# Assert - should propagate skip
|
||||
mock_state_manager.mark_node_skipped.assert_called_once_with("node_2")
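
    # A hedged summary of the decision table these tests pin down (inferred from
    # the assertions above, not from SkipPropagator's own source):
    #   has_unknown=True -> stop and wait for the remaining incoming edges to resolve
    #   has_taken=True   -> enqueue_node(head) and start_execution(head)
    #   all_skipped=True -> mark_node_skipped(head), then recurse to downstream edges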
@@ -8,11 +8,12 @@ class TestCelerySSLConfiguration:
    """Test suite for Celery SSL configuration."""

    def test_get_celery_ssl_options_when_ssl_disabled(self):
        """Test SSL options when REDIS_USE_SSL is False."""
        mock_config = MagicMock()
        mock_config.REDIS_USE_SSL = False
        """Test SSL options when BROKER_USE_SSL is False."""
        from configs import DifyConfig

        with patch("extensions.ext_celery.dify_config", mock_config):
        dify_config = DifyConfig(CELERY_BROKER_URL="redis://localhost:6379/0")

        with patch("extensions.ext_celery.dify_config", dify_config):
            from extensions.ext_celery import _get_celery_ssl_options

            result = _get_celery_ssl_options()
@@ -21,7 +22,6 @@ class TestCelerySSLConfiguration:
    def test_get_celery_ssl_options_when_broker_not_redis(self):
        """Test SSL options when broker is not Redis."""
        mock_config = MagicMock()
        mock_config.REDIS_USE_SSL = True
        mock_config.CELERY_BROKER_URL = "amqp://localhost:5672"

        with patch("extensions.ext_celery.dify_config", mock_config):
@@ -33,7 +33,6 @@ class TestCelerySSLConfiguration:
    def test_get_celery_ssl_options_with_cert_none(self):
        """Test SSL options with CERT_NONE requirement."""
        mock_config = MagicMock()
        mock_config.REDIS_USE_SSL = True
        mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
        mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE"
        mock_config.REDIS_SSL_CA_CERTS = None
@@ -53,7 +52,6 @@ class TestCelerySSLConfiguration:
    def test_get_celery_ssl_options_with_cert_required(self):
        """Test SSL options with CERT_REQUIRED and certificates."""
        mock_config = MagicMock()
        mock_config.REDIS_USE_SSL = True
        mock_config.CELERY_BROKER_URL = "rediss://localhost:6380/0"
        mock_config.REDIS_SSL_CERT_REQS = "CERT_REQUIRED"
        mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt"
@@ -73,7 +71,6 @@ class TestCelerySSLConfiguration:
    def test_get_celery_ssl_options_with_cert_optional(self):
        """Test SSL options with CERT_OPTIONAL requirement."""
        mock_config = MagicMock()
        mock_config.REDIS_USE_SSL = True
        mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
        mock_config.REDIS_SSL_CERT_REQS = "CERT_OPTIONAL"
        mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt"
@@ -91,7 +88,6 @@ class TestCelerySSLConfiguration:
    def test_get_celery_ssl_options_with_invalid_cert_reqs(self):
        """Test SSL options with invalid cert requirement defaults to CERT_NONE."""
        mock_config = MagicMock()
        mock_config.REDIS_USE_SSL = True
        mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
        mock_config.REDIS_SSL_CERT_REQS = "INVALID_VALUE"
        mock_config.REDIS_SSL_CA_CERTS = None
@@ -108,7 +104,6 @@ class TestCelerySSLConfiguration:
    def test_celery_init_applies_ssl_to_broker_and_backend(self):
        """Test that SSL options are applied to both broker and backend when using Redis."""
        mock_config = MagicMock()
        mock_config.REDIS_USE_SSL = True
        mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
        mock_config.CELERY_BACKEND = "redis"
        mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0"
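
    # A rough sketch of the behavior these tests pin down (an assumption inferred
    # from the test names and configs, not from extensions/ext_celery.py itself):
    #   - a non-Redis broker URL (e.g. amqp://) yields no SSL options
    #   - REDIS_SSL_CERT_REQS of CERT_NONE / CERT_OPTIONAL / CERT_REQUIRED maps to
    #     the corresponding ssl.CERT_* constant; unrecognized values fall back to CERT_NONE
    #   - when broker and result backend both use Redis, the SSL options are applied to both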
api/tests/unit_tests/libs/test_archive_storage.py (new file, 272 lines)
@@ -0,0 +1,272 @@
import base64
import hashlib
from datetime import datetime
from unittest.mock import ANY, MagicMock

import pytest
from botocore.exceptions import ClientError

from libs import archive_storage as storage_module
from libs.archive_storage import (
    ArchiveStorage,
    ArchiveStorageError,
    ArchiveStorageNotConfiguredError,
)

BUCKET_NAME = "archive-bucket"


def _configure_storage(monkeypatch, **overrides):
    defaults = {
        "ARCHIVE_STORAGE_ENABLED": True,
        "ARCHIVE_STORAGE_ENDPOINT": "https://storage.example.com",
        "ARCHIVE_STORAGE_ARCHIVE_BUCKET": BUCKET_NAME,
        "ARCHIVE_STORAGE_ACCESS_KEY": "access",
        "ARCHIVE_STORAGE_SECRET_KEY": "secret",
        "ARCHIVE_STORAGE_REGION": "auto",
    }
    defaults.update(overrides)
    for key, value in defaults.items():
        monkeypatch.setattr(storage_module.dify_config, key, value, raising=False)


def _client_error(code: str) -> ClientError:
    return ClientError({"Error": {"Code": code}}, "Operation")


def _mock_client(monkeypatch):
    client = MagicMock()
    client.head_bucket.return_value = None
    boto_client = MagicMock(return_value=client)
    monkeypatch.setattr(storage_module.boto3, "client", boto_client)
    return client, boto_client


def test_init_disabled(monkeypatch):
    _configure_storage(monkeypatch, ARCHIVE_STORAGE_ENABLED=False)
    with pytest.raises(ArchiveStorageNotConfiguredError, match="not enabled"):
        ArchiveStorage(bucket=BUCKET_NAME)


def test_init_missing_config(monkeypatch):
    _configure_storage(monkeypatch, ARCHIVE_STORAGE_ENDPOINT=None)
    with pytest.raises(ArchiveStorageNotConfiguredError, match="incomplete"):
        ArchiveStorage(bucket=BUCKET_NAME)


def test_init_bucket_not_found(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    client.head_bucket.side_effect = _client_error("404")

    with pytest.raises(ArchiveStorageNotConfiguredError, match="does not exist"):
        ArchiveStorage(bucket=BUCKET_NAME)


def test_init_bucket_access_denied(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    client.head_bucket.side_effect = _client_error("403")

    with pytest.raises(ArchiveStorageNotConfiguredError, match="Access denied"):
        ArchiveStorage(bucket=BUCKET_NAME)


def test_init_bucket_other_error(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    client.head_bucket.side_effect = _client_error("500")

    with pytest.raises(ArchiveStorageError, match="Failed to access archive bucket"):
        ArchiveStorage(bucket=BUCKET_NAME)


def test_init_sets_client(monkeypatch):
    _configure_storage(monkeypatch)
    client, boto_client = _mock_client(monkeypatch)

    storage = ArchiveStorage(bucket=BUCKET_NAME)

    boto_client.assert_called_once_with(
        "s3",
        endpoint_url="https://storage.example.com",
        aws_access_key_id="access",
        aws_secret_access_key="secret",
        region_name="auto",
        config=ANY,
    )
    assert storage.client is client
    assert storage.bucket == BUCKET_NAME


def test_put_object_returns_checksum(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    storage = ArchiveStorage(bucket=BUCKET_NAME)

    data = b"hello"
    checksum = storage.put_object("key", data)

    expected_md5 = hashlib.md5(data).hexdigest()
    expected_content_md5 = base64.b64encode(hashlib.md5(data).digest()).decode()
    client.put_object.assert_called_once_with(
        Bucket="archive-bucket",
        Key="key",
        Body=data,
        ContentMD5=expected_content_md5,
    )
    assert checksum == expected_md5


def test_put_object_raises_on_error(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    storage = ArchiveStorage(bucket=BUCKET_NAME)
    client.put_object.side_effect = _client_error("500")

    with pytest.raises(ArchiveStorageError, match="Failed to upload object"):
        storage.put_object("key", b"data")


def test_get_object_returns_bytes(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    body = MagicMock()
    body.read.return_value = b"payload"
    client.get_object.return_value = {"Body": body}
    storage = ArchiveStorage(bucket=BUCKET_NAME)

    assert storage.get_object("key") == b"payload"


def test_get_object_missing(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    client.get_object.side_effect = _client_error("NoSuchKey")
    storage = ArchiveStorage(bucket=BUCKET_NAME)

    with pytest.raises(FileNotFoundError, match="Archive object not found"):
        storage.get_object("missing")


def test_get_object_stream(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    body = MagicMock()
    body.iter_chunks.return_value = [b"a", b"b"]
    client.get_object.return_value = {"Body": body}
    storage = ArchiveStorage(bucket=BUCKET_NAME)

    assert list(storage.get_object_stream("key")) == [b"a", b"b"]


def test_get_object_stream_missing(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    client.get_object.side_effect = _client_error("NoSuchKey")
    storage = ArchiveStorage(bucket=BUCKET_NAME)

    with pytest.raises(FileNotFoundError, match="Archive object not found"):
        list(storage.get_object_stream("missing"))


def test_object_exists(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    storage = ArchiveStorage(bucket=BUCKET_NAME)

    assert storage.object_exists("key") is True
    client.head_object.side_effect = _client_error("404")
    assert storage.object_exists("missing") is False


def test_delete_object_error(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    client.delete_object.side_effect = _client_error("500")
    storage = ArchiveStorage(bucket=BUCKET_NAME)

    with pytest.raises(ArchiveStorageError, match="Failed to delete object"):
        storage.delete_object("key")


def test_list_objects(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    paginator = MagicMock()
    paginator.paginate.return_value = [
        {"Contents": [{"Key": "a"}, {"Key": "b"}]},
        {"Contents": [{"Key": "c"}]},
    ]
    client.get_paginator.return_value = paginator
    storage = ArchiveStorage(bucket=BUCKET_NAME)

    assert storage.list_objects("prefix") == ["a", "b", "c"]
    paginator.paginate.assert_called_once_with(Bucket="archive-bucket", Prefix="prefix")


def test_list_objects_error(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    paginator = MagicMock()
    paginator.paginate.side_effect = _client_error("500")
    client.get_paginator.return_value = paginator
    storage = ArchiveStorage(bucket=BUCKET_NAME)

    with pytest.raises(ArchiveStorageError, match="Failed to list objects"):
        storage.list_objects("prefix")


def test_generate_presigned_url(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    client.generate_presigned_url.return_value = "http://signed-url"
    storage = ArchiveStorage(bucket=BUCKET_NAME)

    url = storage.generate_presigned_url("key", expires_in=123)

    client.generate_presigned_url.assert_called_once_with(
        ClientMethod="get_object",
        Params={"Bucket": "archive-bucket", "Key": "key"},
        ExpiresIn=123,
    )
    assert url == "http://signed-url"


def test_generate_presigned_url_error(monkeypatch):
    _configure_storage(monkeypatch)
    client, _ = _mock_client(monkeypatch)
    client.generate_presigned_url.side_effect = _client_error("500")
    storage = ArchiveStorage(bucket=BUCKET_NAME)

    with pytest.raises(ArchiveStorageError, match="Failed to generate pre-signed URL"):
        storage.generate_presigned_url("key")


def test_serialization_roundtrip():
    records = [
        {
            "id": "1",
            "created_at": datetime(2024, 1, 1, 12, 0, 0),
            "payload": {"nested": "value"},
            "items": [{"name": "a"}],
        },
        {"id": "2", "value": 123},
    ]

    data = ArchiveStorage.serialize_to_jsonl_gz(records)
    decoded = ArchiveStorage.deserialize_from_jsonl_gz(data)

    assert decoded[0]["id"] == "1"
    assert decoded[0]["payload"]["nested"] == "value"
    assert decoded[0]["items"][0]["name"] == "a"
    assert "2024-01-01T12:00:00" in decoded[0]["created_at"]
    assert decoded[1]["value"] == 123


def test_content_md5_matches_checksum():
    data = b"checksum"
    expected = base64.b64encode(hashlib.md5(data).digest()).decode()

    assert ArchiveStorage._content_md5(data) == expected
    assert ArchiveStorage.compute_checksum(data) == hashlib.md5(data).hexdigest()
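
# A minimal sketch of what these helpers presumably compute, inferred from the
# assertions above rather than from libs/archive_storage.py itself: the S3
# ContentMD5 header carries the base64-encoded binary MD5 digest, while the
# stored checksum is the hex digest of the same hash.
#
#   @staticmethod
#   def _content_md5(data: bytes) -> str:
#       return base64.b64encode(hashlib.md5(data).digest()).decode()
#
#   @staticmethod
#   def compute_checksum(data: bytes) -> str:
#       return hashlib.md5(data).hexdigest()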
@@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
from qcloud_cos import CosConfig
@@ -18,3 +18,72 @@ class TestTencentCos(BaseStorageTest):
        with patch.object(CosConfig, "__init__", return_value=None):
            self.storage = TencentCosStorage()
        self.storage.bucket_name = get_example_bucket()


class TestTencentCosConfiguration:
    """Tests for TencentCosStorage initialization with different configurations."""

    def test_init_with_custom_domain(self):
        """Test initialization with custom domain configured."""
        # Mock dify_config to return custom domain configuration
        mock_dify_config = MagicMock()
        mock_dify_config.TENCENT_COS_CUSTOM_DOMAIN = "cos.example.com"
        mock_dify_config.TENCENT_COS_SECRET_ID = "test-secret-id"
        mock_dify_config.TENCENT_COS_SECRET_KEY = "test-secret-key"
        mock_dify_config.TENCENT_COS_SCHEME = "https"

        # Mock CosConfig and CosS3Client
        mock_config_instance = MagicMock()
        mock_client = MagicMock()

        with (
            patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config),
            patch(
                "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance
            ) as mock_cos_config,
            patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client),
        ):
            TencentCosStorage()

        # Verify CosConfig was called with Domain parameter (not Region)
        mock_cos_config.assert_called_once()
        call_kwargs = mock_cos_config.call_args[1]
        assert "Domain" in call_kwargs
        assert call_kwargs["Domain"] == "cos.example.com"
        assert "Region" not in call_kwargs
        assert call_kwargs["SecretId"] == "test-secret-id"
        assert call_kwargs["SecretKey"] == "test-secret-key"
        assert call_kwargs["Scheme"] == "https"

    def test_init_with_region(self):
        """Test initialization with region configured (no custom domain)."""
        # Mock dify_config to return region configuration
        mock_dify_config = MagicMock()
        mock_dify_config.TENCENT_COS_CUSTOM_DOMAIN = None
        mock_dify_config.TENCENT_COS_REGION = "ap-guangzhou"
        mock_dify_config.TENCENT_COS_SECRET_ID = "test-secret-id"
        mock_dify_config.TENCENT_COS_SECRET_KEY = "test-secret-key"
        mock_dify_config.TENCENT_COS_SCHEME = "https"

        # Mock CosConfig and CosS3Client
        mock_config_instance = MagicMock()
        mock_client = MagicMock()

        with (
            patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config),
            patch(
                "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance
            ) as mock_cos_config,
            patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client),
        ):
            TencentCosStorage()

        # Verify CosConfig was called with Region parameter (not Domain)
        mock_cos_config.assert_called_once()
        call_kwargs = mock_cos_config.call_args[1]
        assert "Region" in call_kwargs
        assert call_kwargs["Region"] == "ap-guangzhou"
        assert "Domain" not in call_kwargs
        assert call_kwargs["SecretId"] == "test-secret-id"
        assert call_kwargs["SecretKey"] == "test-secret-key"
        assert call_kwargs["Scheme"] == "https"
@@ -1294,6 +1294,42 @@ class TestBillingServiceSubscriptionOperations:
        # Assert
        assert result == {}

    def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request):
        """Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant)."""
        # Arrange
        tenant_ids = ["tenant-valid-1", "tenant-invalid", "tenant-valid-2"]

        # Response with one invalid tenant plan (missing expiration_date) and two valid ones
        mock_send_request.return_value = {
            "data": {
                "tenant-valid-1": {"plan": "sandbox", "expiration_date": 1735689600},
                "tenant-invalid": {"plan": "professional"},  # Missing expiration_date field
                "tenant-valid-2": {"plan": "team", "expiration_date": 1767225600},
            }
        }

        # Act
        with patch("services.billing_service.logger") as mock_logger:
            result = BillingService.get_plan_bulk(tenant_ids)

        # Assert - should only contain valid tenants
        assert len(result) == 2
        assert "tenant-valid-1" in result
        assert "tenant-valid-2" in result
        assert "tenant-invalid" not in result

        # Verify valid tenants have correct data
        assert result["tenant-valid-1"]["plan"] == "sandbox"
        assert result["tenant-valid-1"]["expiration_date"] == 1735689600
        assert result["tenant-valid-2"]["plan"] == "team"
        assert result["tenant-valid-2"]["expiration_date"] == 1767225600

        # Verify the exception was logged for the invalid tenant
        mock_logger.exception.assert_called_once()
        log_call_args = mock_logger.exception.call_args[0]
        assert "get_plan_bulk: failed to validate subscription plan for tenant" in log_call_args[0]
        assert "tenant-invalid" in log_call_args[1]

    def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request):
        """Test successful retrieval of expired subscription cleanup whitelist."""
        # Arrange
@@ -15,6 +15,11 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
        ("", ""),
        (" ", " "),
        ("【测试】", "【测试】"),
        # Markdown link preservation - should be preserved if text starts with a markdown link
        ("[Google](https://google.com) is a search engine", "[Google](https://google.com) is a search engine"),
        ("[Example](http://example.com) some text", "[Example](http://example.com) some text"),
        # Leading symbols before a markdown link are removed, including the opening bracket [
        ("@[Test](https://example.com)", "Test](https://example.com)"),
    ],
)
def test_remove_leading_symbols(input_text, expected_output):
@@ -447,6 +447,15 @@ S3_SECRET_KEY=
# If set to false, the access key and secret key must be provided.
S3_USE_AWS_MANAGED_IAM=false

# Workflow run and Conversation archive storage (S3-compatible)
ARCHIVE_STORAGE_ENABLED=false
ARCHIVE_STORAGE_ENDPOINT=
ARCHIVE_STORAGE_ARCHIVE_BUCKET=
ARCHIVE_STORAGE_EXPORT_BUCKET=
ARCHIVE_STORAGE_ACCESS_KEY=
ARCHIVE_STORAGE_SECRET_KEY=
ARCHIVE_STORAGE_REGION=auto
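# Example values (an illustration only; the assumption is that any S3-compatible
# endpoint works here, e.g. MinIO or Cloudflare R2):
# ARCHIVE_STORAGE_ENDPOINT=http://minio:9000
# ARCHIVE_STORAGE_ARCHIVE_BUCKET=dify-archive
# ARCHIVE_STORAGE_EXPORT_BUCKET=dify-export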

# Azure Blob Configuration
#
AZURE_BLOB_ACCOUNT_NAME=difyai
@@ -478,6 +487,7 @@ TENCENT_COS_SECRET_KEY=your-secret-key
TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain

# Oracle Storage Configuration
#
@@ -122,6 +122,13 @@ x-shared-env: &shared-api-worker-env
  S3_ACCESS_KEY: ${S3_ACCESS_KEY:-}
  S3_SECRET_KEY: ${S3_SECRET_KEY:-}
  S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false}
  ARCHIVE_STORAGE_ENABLED: ${ARCHIVE_STORAGE_ENABLED:-false}
  ARCHIVE_STORAGE_ENDPOINT: ${ARCHIVE_STORAGE_ENDPOINT:-}
  ARCHIVE_STORAGE_ARCHIVE_BUCKET: ${ARCHIVE_STORAGE_ARCHIVE_BUCKET:-}
  ARCHIVE_STORAGE_EXPORT_BUCKET: ${ARCHIVE_STORAGE_EXPORT_BUCKET:-}
  ARCHIVE_STORAGE_ACCESS_KEY: ${ARCHIVE_STORAGE_ACCESS_KEY:-}
  ARCHIVE_STORAGE_SECRET_KEY: ${ARCHIVE_STORAGE_SECRET_KEY:-}
  ARCHIVE_STORAGE_REGION: ${ARCHIVE_STORAGE_REGION:-auto}
  AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai}
  AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai}
  AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container}
@@ -141,6 +148,7 @@ x-shared-env: &shared-api-worker-env
  TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id}
  TENCENT_COS_REGION: ${TENCENT_COS_REGION:-your-region}
  TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-your-scheme}
  TENCENT_COS_CUSTOM_DOMAIN: ${TENCENT_COS_CUSTOM_DOMAIN:-your-custom-domain}
  OCI_ENDPOINT: ${OCI_ENDPOINT:-https://your-object-storage-namespace.compat.objectstorage.us-ashburn-1.oraclecloud.com}
  OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-your-bucket-name}
  OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-your-access-key}
@@ -1,6 +1,7 @@
import type { Plan, UsagePlanInfo } from '@/app/components/billing/type'
import type { ProviderContextState } from '@/context/provider-context'
import { merge, noop } from 'es-toolkit/compat'
import { merge } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { defaultPlan } from '@/app/components/billing/config'

// Avoid being mocked in tests
@@ -3,7 +3,7 @@ import path from 'node:path'
import vm from 'node:vm'
import { transpile } from 'typescript'

describe('check-i18n script functionality', () => {
describe('i18n:check script functionality', () => {
  const testDir = path.join(__dirname, '../i18n-test')
  const testEnDir = path.join(testDir, 'en-US')
  const testZhDir = path.join(testDir, 'zh-Hans')
@@ -16,7 +16,7 @@ const getSupportedLocales = (): string[] => {

// Helper function to load translation file content
const loadTranslationContent = (locale: string): string => {
  const filePath = path.join(I18N_DIR, locale, 'app-debug.ts')
  const filePath = path.join(I18N_DIR, locale, 'app-debug.json')

  if (!fs.existsSync(filePath))
    throw new Error(`Translation file not found: ${filePath}`)
@@ -24,14 +24,14 @@ const loadTranslationContent = (locale: string): string => {
  return fs.readFileSync(filePath, 'utf-8')
}

// Helper function to check if upload features exist
// Helper function to check if upload features exist (supports flattened JSON)
const hasUploadFeatures = (content: string): { [key: string]: boolean } => {
  return {
    fileUpload: /fileUpload\s*:\s*\{/.test(content),
    imageUpload: /imageUpload\s*:\s*\{/.test(content),
    documentUpload: /documentUpload\s*:\s*\{/.test(content),
    audioUpload: /audioUpload\s*:\s*\{/.test(content),
    featureBar: /bar\s*:\s*\{/.test(content),
    fileUpload: /"feature\.fileUpload\.title"/.test(content),
    imageUpload: /"feature\.imageUpload\.title"/.test(content),
    documentUpload: /"feature\.documentUpload\.title"/.test(content),
    audioUpload: /"feature\.audioUpload\.title"/.test(content),
    featureBar: /"feature\.bar\.empty"/.test(content),
  }
}

@@ -45,7 +45,7 @@ describe('Upload Features i18n Translations - Issue #23062', () => {

  it('all locales should have translation files', () => {
    supportedLocales.forEach((locale) => {
      const filePath = path.join(I18N_DIR, locale, 'app-debug.ts')
      const filePath = path.join(I18N_DIR, locale, 'app-debug.json')
      expect(fs.existsSync(filePath)).toBe(true)
    })
  })
@@ -76,12 +76,9 @@ describe('Upload Features i18n Translations - Issue #23062', () => {
    previouslyMissingLocales.forEach((locale) => {
      const content = loadTranslationContent(locale)

      // Verify audioUpload exists
      expect(/audioUpload\s*:\s*\{/.test(content)).toBe(true)

      // Verify it has title and description
      expect(/audioUpload[^}]*title\s*:/.test(content)).toBe(true)
      expect(/audioUpload[^}]*description\s*:/.test(content)).toBe(true)
      // Verify audioUpload exists with title and description (flattened JSON format)
      expect(/"feature\.audioUpload\.title"/.test(content)).toBe(true)
      expect(/"feature\.audioUpload\.description"/.test(content)).toBe(true)

      console.log(`✅ ${locale} - Issue #23062 resolved: audioUpload feature present`)
    })
@@ -91,28 +88,28 @@ describe('Upload Features i18n Translations - Issue #23062', () => {
    supportedLocales.forEach((locale) => {
      const content = loadTranslationContent(locale)

      // Check fileUpload has required properties
      if (/fileUpload\s*:\s*\{/.test(content)) {
        expect(/fileUpload[^}]*title\s*:/.test(content)).toBe(true)
        expect(/fileUpload[^}]*description\s*:/.test(content)).toBe(true)
      // Check fileUpload has required properties (flattened JSON format)
      if (/"feature\.fileUpload\.title"/.test(content)) {
        expect(/"feature\.fileUpload\.title"/.test(content)).toBe(true)
        expect(/"feature\.fileUpload\.description"/.test(content)).toBe(true)
      }

      // Check imageUpload has required properties
      if (/imageUpload\s*:\s*\{/.test(content)) {
        expect(/imageUpload[^}]*title\s*:/.test(content)).toBe(true)
        expect(/imageUpload[^}]*description\s*:/.test(content)).toBe(true)
      if (/"feature\.imageUpload\.title"/.test(content)) {
        expect(/"feature\.imageUpload\.title"/.test(content)).toBe(true)
        expect(/"feature\.imageUpload\.description"/.test(content)).toBe(true)
      }

      // Check documentUpload has required properties
      if (/documentUpload\s*:\s*\{/.test(content)) {
        expect(/documentUpload[^}]*title\s*:/.test(content)).toBe(true)
        expect(/documentUpload[^}]*description\s*:/.test(content)).toBe(true)
      if (/"feature\.documentUpload\.title"/.test(content)) {
        expect(/"feature\.documentUpload\.title"/.test(content)).toBe(true)
        expect(/"feature\.documentUpload\.description"/.test(content)).toBe(true)
      }

      // Check audioUpload has required properties
      if (/audioUpload\s*:\s*\{/.test(content)) {
        expect(/audioUpload[^}]*title\s*:/.test(content)).toBe(true)
        expect(/audioUpload[^}]*description\s*:/.test(content)).toBe(true)
      if (/"feature\.audioUpload\.title"/.test(content)) {
        expect(/"feature\.audioUpload\.title"/.test(content)).toBe(true)
        expect(/"feature\.audioUpload\.description"/.test(content)).toBe(true)
      }
    })
  })
@@ -64,7 +64,6 @@ vi.mock('i18next', () => ({

// Mock the useConfig hook
vi.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({
  __esModule: true,
  default: () => ({
    inputs: {
      is_parallel: true,
@@ -70,7 +70,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
    const navConfig = [
      ...(isCurrentWorkspaceEditor
        ? [{
          name: t('common.appMenus.promptEng'),
          name: t('appMenus.promptEng', { ns: 'common' }),
          href: `/app/${appId}/${(mode === AppModeEnum.WORKFLOW || mode === AppModeEnum.ADVANCED_CHAT) ? 'workflow' : 'configuration'}`,
          icon: RiTerminalWindowLine,
          selectedIcon: RiTerminalWindowFill,
@@ -78,7 +78,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
        : []
      ),
      {
        name: t('common.appMenus.apiAccess'),
        name: t('appMenus.apiAccess', { ns: 'common' }),
        href: `/app/${appId}/develop`,
        icon: RiTerminalBoxLine,
        selectedIcon: RiTerminalBoxFill,
@@ -86,8 +86,8 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
      ...(isCurrentWorkspaceEditor
        ? [{
          name: mode !== AppModeEnum.WORKFLOW
            ? t('common.appMenus.logAndAnn')
            : t('common.appMenus.logs'),
            ? t('appMenus.logAndAnn', { ns: 'common' })
            : t('appMenus.logs', { ns: 'common' }),
          href: `/app/${appId}/logs`,
          icon: RiFileList3Line,
          selectedIcon: RiFileList3Fill,
@@ -95,7 +95,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
        : []
      ),
      {
        name: t('common.appMenus.overview'),
        name: t('appMenus.overview', { ns: 'common' }),
        href: `/app/${appId}/overview`,
        icon: RiDashboard2Line,
        selectedIcon: RiDashboard2Fill,
@@ -104,7 +104,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
    return navConfig
  }, [t])

  useDocumentTitle(appDetail?.name || t('common.menus.appDetail'))
  useDocumentTitle(appDetail?.name || t('menus.appDetail', { ns: 'common' }))

  useEffect(() => {
    if (appDetail) {
@@ -4,6 +4,7 @@ import type { IAppCardProps } from '@/app/components/app/overview/app-card'
import type { BlockEnum } from '@/app/components/workflow/types'
import type { UpdateAppSiteCodeResponse } from '@/models/app'
import type { App } from '@/types/app'
import type { I18nKeysByPrefix } from '@/types/i18n'
import * as React from 'react'
import { useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
@@ -62,7 +63,7 @@ const CardView: FC<ICardViewProps> = ({ appId, isInPanel, className }) => {
  const buildTriggerModeMessage = useCallback((featureName: string) => (
    <div className="flex flex-col gap-1">
      <div className="text-xs text-text-secondary">
        {t('appOverview.overview.disableTooltip.triggerMode', { feature: featureName })}
        {t('overview.disableTooltip.triggerMode', { ns: 'appOverview', feature: featureName })}
      </div>
      <div
        className="cursor-pointer text-xs font-medium text-text-accent hover:underline"
@@ -71,19 +72,19 @@ const CardView: FC<ICardViewProps> = ({ appId, isInPanel, className }) => {
          window.open(triggerDocUrl, '_blank')
        }}
      >
        {t('appOverview.overview.appInfo.enableTooltip.learnMore')}
        {t('overview.appInfo.enableTooltip.learnMore', { ns: 'appOverview' })}
      </div>
    </div>
  ), [t, triggerDocUrl])

  const disableWebAppTooltip = disableAppCards
    ? buildTriggerModeMessage(t('appOverview.overview.appInfo.title'))
    ? buildTriggerModeMessage(t('overview.appInfo.title', { ns: 'appOverview' }))
    : null
  const disableApiTooltip = disableAppCards
    ? buildTriggerModeMessage(t('appOverview.overview.apiInfo.title'))
    ? buildTriggerModeMessage(t('overview.apiInfo.title', { ns: 'appOverview' }))
    : null
  const disableMcpTooltip = disableAppCards
    ? buildTriggerModeMessage(t('tools.mcp.server.title'))
    ? buildTriggerModeMessage(t('mcp.server.title', { ns: 'tools' }))
    : null

  const updateAppDetail = async () => {
@@ -94,7 +95,7 @@ const CardView: FC<ICardViewProps> = ({ appId, isInPanel, className }) => {
    catch (error) { console.error(error) }
  }

  const handleCallbackResult = (err: Error | null, message?: string) => {
  const handleCallbackResult = (err: Error | null, message?: I18nKeysByPrefix<'common', 'actionMsg.'>) => {
    const type = err ? 'error' : 'success'

    message ||= (type === 'success' ? 'modifiedSuccessfully' : 'modifiedUnsuccessfully')
@@ -104,7 +105,7 @@ const CardView: FC<ICardViewProps> = ({ appId, isInPanel, className }) => {

    notify({
      type,
      message: t(`common.actionMsg.${message}` as any) as string,
      message: t(`actionMsg.${message}`, { ns: 'common' }) as string,
    })
  }
@@ -1,5 +1,6 @@
'use client'
import type { PeriodParams } from '@/app/components/app/overview/app-chart'
import type { I18nKeysByPrefix } from '@/types/i18n'
import dayjs from 'dayjs'
import quarterOfYear from 'dayjs/plugin/quarterOfYear'
import * as React from 'react'
@@ -16,7 +17,9 @@ dayjs.extend(quarterOfYear)

const today = dayjs()

const TIME_PERIOD_MAPPING = [
type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'>

const TIME_PERIOD_MAPPING: { value: number, name: TimePeriodName }[] = [
  { value: 0, name: 'today' },
  { value: 7, name: 'last7days' },
  { value: 30, name: 'last30days' },
@@ -35,8 +38,8 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) {
  const isChatApp = appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow'
  const isWorkflow = appDetail?.mode === 'workflow'
  const [period, setPeriod] = useState<PeriodParams>(IS_CLOUD_EDITION
    ? { name: t('appLog.filter.period.today'), query: { start: today.startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }
    : { name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } },
    ? { name: t('filter.period.today', { ns: 'appLog' }), query: { start: today.startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }
    : { name: t('filter.period.last7days', { ns: 'appLog' }), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } },
  )

  if (!appDetail)
@@ -45,7 +48,7 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) {
  return (
    <div>
      <div className="mb-4">
        <div className="system-xl-semibold mb-2 text-text-primary">{t('common.appMenus.overview')}</div>
        <div className="system-xl-semibold mb-2 text-text-primary">{t('appMenus.overview', { ns: 'common' })}</div>
        <div className="flex items-center justify-between">
          {IS_CLOUD_EDITION
            ? (
@@ -2,13 +2,16 @@
import type { FC } from 'react'
import type { PeriodParams } from '@/app/components/app/overview/app-chart'
import type { Item } from '@/app/components/base/select'
import type { I18nKeysByPrefix } from '@/types/i18n'
import dayjs from 'dayjs'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { SimpleSelect } from '@/app/components/base/select'

type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'>

type Props = {
  periodMapping: { [key: string]: { value: number, name: string } }
  periodMapping: { [key: string]: { value: number, name: TimePeriodName } }
  onSelect: (payload: PeriodParams) => void
  queryDateFormat: string
}
@@ -25,9 +28,9 @@ const LongTimeRangePicker: FC<Props> = ({
  const handleSelect = React.useCallback((item: Item) => {
    const id = item.value
    const value = periodMapping[id]?.value ?? '-1'
    const name = item.name || t('appLog.filter.period.allTime')
    const name = item.name || t('filter.period.allTime', { ns: 'appLog' })
    if (value === -1) {
      onSelect({ name: t('appLog.filter.period.allTime'), query: undefined })
      onSelect({ name: t('filter.period.allTime', { ns: 'appLog' }), query: undefined })
    }
    else if (value === 0) {
      const startOfToday = today.startOf('day').format(queryDateFormat)
@@ -53,7 +56,7 @@ const LongTimeRangePicker: FC<Props> = ({

  return (
    <SimpleSelect
      items={Object.entries(periodMapping).map(([k, v]) => ({ value: k, name: t(`appLog.filter.period.${v.name}` as any) as string }))}
      items={Object.entries(periodMapping).map(([k, v]) => ({ value: k, name: t(`filter.period.${v.name}`, { ns: 'appLog' }) }))}
      className="mt-0 !w-40"
      notClearable={true}
      onSelect={handleSelect}
@@ -4,11 +4,11 @@ import type { FC } from 'react'
import type { TriggerProps } from '@/app/components/base/date-and-time-picker/types'
import { RiCalendarLine } from '@remixicon/react'
import dayjs from 'dayjs'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
import { useCallback } from 'react'
import Picker from '@/app/components/base/date-and-time-picker/date-picker'
import { useI18N } from '@/context/i18n'
import { useLocale } from '@/context/i18n'
import { cn } from '@/utils/classnames'
import { formatToLocalTime } from '@/utils/format'

@@ -26,7 +26,7 @@ const DatePicker: FC<Props> = ({
  onStartChange,
  onEndChange,
}) => {
  const { locale } = useI18N()
  const locale = useLocale()

  const renderDate = useCallback(({ value, handleClickTrigger, isOpen }: TriggerProps) => {
    return (
@@ -2,19 +2,22 @@
import type { Dayjs } from 'dayjs'
import type { FC } from 'react'
import type { PeriodParams, PeriodParamsWithTimeRange } from '@/app/components/app/overview/app-chart'
import type { I18nKeysByPrefix } from '@/types/i18n'
import dayjs from 'dayjs'
import * as React from 'react'
import { useCallback, useState } from 'react'
import { HourglassShape } from '@/app/components/base/icons/src/vender/other'
import { useI18N } from '@/context/i18n'
import { useLocale } from '@/context/i18n'
import { formatToLocalTime } from '@/utils/format'
import DatePicker from './date-picker'
import RangeSelector from './range-selector'

const today = dayjs()

type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'>

type Props = {
  ranges: { value: number, name: string }[]
  ranges: { value: number, name: TimePeriodName }[]
  onSelect: (payload: PeriodParams) => void
  queryDateFormat: string
}
@@ -24,7 +27,7 @@ const TimeRangePicker: FC<Props> = ({
  onSelect,
  queryDateFormat,
}) => {
  const { locale } = useI18N()
  const locale = useLocale()

  const [isCustomRange, setIsCustomRange] = useState(false)
  const [start, setStart] = useState<Dayjs>(today)
@@ -2,6 +2,7 @@
import type { FC } from 'react'
import type { PeriodParamsWithTimeRange, TimeRange } from '@/app/components/app/overview/app-chart'
import type { Item } from '@/app/components/base/select'
import type { I18nKeysByPrefix } from '@/types/i18n'
import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react'
import dayjs from 'dayjs'
import * as React from 'react'
@@ -12,9 +13,11 @@ import { cn } from '@/utils/classnames'

const today = dayjs()

type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'>

type Props = {
  isCustomRange: boolean
  ranges: { value: number, name: string }[]
  ranges: { value: number, name: TimePeriodName }[]
  onSelect: (payload: PeriodParamsWithTimeRange) => void
}

@@ -42,7 +45,7 @@ const RangeSelector: FC<Props> = ({
  const renderTrigger = useCallback((item: Item | null, isOpen: boolean) => {
    return (
      <div className={cn('flex h-8 cursor-pointer items-center space-x-1.5 rounded-lg bg-components-input-bg-normal pl-3 pr-2', isOpen && 'bg-state-base-hover-alt')}>
        <div className="system-sm-regular text-components-input-text-filled">{isCustomRange ? t('appLog.filter.period.custom') : item?.name}</div>
        <div className="system-sm-regular text-components-input-text-filled">{isCustomRange ? t('filter.period.custom', { ns: 'appLog' }) : item?.name}</div>
        <RiArrowDownSLine className={cn('size-4 text-text-quaternary', isOpen && 'text-text-secondary')} />
      </div>
    )
@@ -66,7 +69,7 @@ const RangeSelector: FC<Props> = ({
  }, [])
  return (
    <SimpleSelect
      items={ranges.map(v => ({ ...v, name: t(`appLog.filter.period.${v.name}` as any) as string }))}
      items={ranges.map(v => ({ ...v, name: t(`filter.period.${v.name}`, { ns: 'appLog' }) }))}
      className="mt-0 !w-40"
      notClearable={true}
      onSelect={handleSelectRange}
@@ -15,7 +15,7 @@ import ProviderPanel from './provider-panel'
import TracingIcon from './tracing-icon'
import { TracingProvider } from './type'

const I18N_PREFIX = 'app.tracing'
const I18N_PREFIX = 'tracing'

export type PopupProps = {
  appId: string
@@ -327,19 +327,19 @@ const ConfigPopup: FC<PopupProps> = ({
      <div className="flex items-center justify-between">
        <div className="flex items-center">
          <TracingIcon size="md" className="mr-2" />
          <div className="title-2xl-semi-bold text-text-primary">{t(`${I18N_PREFIX}.tracing`)}</div>
          <div className="title-2xl-semi-bold text-text-primary">{t(`${I18N_PREFIX}.tracing`, { ns: 'app' })}</div>
        </div>
        <div className="flex items-center">
          <Indicator color={enabled ? 'green' : 'gray'} />
          <div className={cn('system-xs-semibold-uppercase ml-1 text-text-tertiary', enabled && 'text-util-colors-green-green-600')}>
            {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`)}
            {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`, { ns: 'app' })}
          </div>
          {!readOnly && (
            <>
              {providerAllNotConfigured
                ? (
                  <Tooltip
                    popupContent={t(`${I18N_PREFIX}.disabledTip`)}
                    popupContent={t(`${I18N_PREFIX}.disabledTip`, { ns: 'app' })}
                  >
                    {switchContent}
                  </Tooltip>
@@ -351,14 +351,14 @@ const ConfigPopup: FC<PopupProps> = ({
      </div>

      <div className="system-xs-regular mt-2 text-text-tertiary">
        {t(`${I18N_PREFIX}.tracingDescription`)}
        {t(`${I18N_PREFIX}.tracingDescription`, { ns: 'app' })}
      </div>
      <Divider className="my-3" />
      <div className="relative">
        {(providerAllConfigured || providerAllNotConfigured)
          ? (
            <>
              <div className="system-xs-medium-uppercase text-text-tertiary">{t(`${I18N_PREFIX}.configProviderTitle.${providerAllConfigured ? 'configured' : 'notConfigured'}`)}</div>
              <div className="system-xs-medium-uppercase text-text-tertiary">{t(`${I18N_PREFIX}.configProviderTitle.${providerAllConfigured ? 'configured' : 'notConfigured'}`, { ns: 'app' })}</div>
              <div className="mt-2 max-h-96 space-y-2 overflow-y-auto">
                {langfusePanel}
                {langSmithPanel}
@@ -375,11 +375,11 @@ const ConfigPopup: FC<PopupProps> = ({
          )
          : (
            <>
              <div className="system-xs-medium-uppercase text-text-tertiary">{t(`${I18N_PREFIX}.configProviderTitle.configured`)}</div>
              <div className="system-xs-medium-uppercase text-text-tertiary">{t(`${I18N_PREFIX}.configProviderTitle.configured`, { ns: 'app' })}</div>
              <div className="mt-2 max-h-40 space-y-2 overflow-y-auto">
                {configuredProviderPanel()}
              </div>
              <div className="system-xs-medium-uppercase mt-3 text-text-tertiary">{t(`${I18N_PREFIX}.configProviderTitle.moreProvider`)}</div>
              <div className="system-xs-medium-uppercase mt-3 text-text-tertiary">{t(`${I18N_PREFIX}.configProviderTitle.moreProvider`, { ns: 'app' })}</div>
              <div className="mt-2 max-h-40 space-y-2 overflow-y-auto">
                {moreProviderPanel()}
              </div>
@@ -23,7 +23,7 @@ import ConfigButton from './config-button'
import TracingIcon from './tracing-icon'
import { TracingProvider } from './type'

const I18N_PREFIX = 'app.tracing'
const I18N_PREFIX = 'tracing'

const Panel: FC = () => {
  const { t } = useTranslation()
@@ -45,7 +45,7 @@ const Panel: FC = () => {
    if (!noToast) {
      Toast.notify({
        type: 'success',
        message: t('common.api.success'),
        message: t('api.success', { ns: 'common' }),
      })
    }
  }
@@ -254,7 +254,7 @@ const Panel: FC = () => {
        )}
      >
        <TracingIcon size="md" />
        <div className="system-sm-semibold mx-2 text-text-secondary">{t(`${I18N_PREFIX}.title`)}</div>
        <div className="system-sm-semibold mx-2 text-text-secondary">{t(`${I18N_PREFIX}.title`, { ns: 'app' })}</div>
        <div className="rounded-md p-1">
          <RiEqualizer2Line className="h-4 w-4 text-text-tertiary" />
        </div>
@@ -295,7 +295,7 @@ const Panel: FC = () => {
        <div className="ml-4 mr-1 flex items-center">
          <Indicator color={enabled ? 'green' : 'gray'} />
          <div className="system-xs-semibold-uppercase ml-1.5 text-text-tertiary">
            {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`)}
            {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`, { ns: 'app' })}
          </div>
        </div>
        {InUseProviderIcon && <InUseProviderIcon className="ml-1 h-4" />}
@ -30,7 +30,7 @@ type Props = {
|
||||
onChosen: (provider: TracingProvider) => void
|
||||
}
|
||||
|
||||
const I18N_PREFIX = 'app.tracing.configProvider'
|
||||
const I18N_PREFIX = 'tracing.configProvider'
|
||||
|
||||
const arizeConfigTemplate = {
|
||||
api_key: '',
|
||||
@ -157,7 +157,7 @@ const ProviderConfigModal: FC<Props> = ({
|
||||
})
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('common.api.remove'),
|
||||
message: t('api.remove', { ns: 'common' }),
|
||||
})
|
||||
onRemoved()
|
||||
hideRemoveConfirm()
|
||||
@ -177,37 +177,37 @@ const ProviderConfigModal: FC<Props> = ({
|
||||
if (type === TracingProvider.arize) {
|
||||
const postData = config as ArizeConfig
|
||||
if (!postData.api_key)
|
||||
errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' })
|
||||
errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' })
|
||||
if (!postData.space_id)
|
||||
errorMessage = t('common.errorMsg.fieldRequired', { field: 'Space ID' })
|
||||
errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Space ID' })
|
||||
if (!errorMessage && !postData.project)
|
||||
errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.project`) })
|
||||
errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })
|
||||
}
|
||||
|
||||
if (type === TracingProvider.phoenix) {
|
||||
const postData = config as PhoenixConfig
|
||||
if (!postData.api_key)
|
||||
errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' })
|
||||
errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' })
|
||||
if (!errorMessage && !postData.project)
|
||||
errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.project`) })
|
||||
errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })
|
||||
}
|
||||
|
||||
    if (type === TracingProvider.langSmith) {
      const postData = config as LangSmithConfig
      if (!postData.api_key)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' })
      if (!errorMessage && !postData.project)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.project`) })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })
    }

    if (type === TracingProvider.langfuse) {
      const postData = config as LangFuseConfig
      if (!errorMessage && !postData.secret_key)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.secretKey`) })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.secretKey`, { ns: 'app' }) })
      if (!errorMessage && !postData.public_key)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.publicKey`) })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.publicKey`, { ns: 'app' }) })
      if (!errorMessage && !postData.host)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'Host' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Host' })
    }

    if (type === TracingProvider.opik) {

@ -218,43 +218,43 @@ const ProviderConfigModal: FC<Props> = ({
    if (type === TracingProvider.weave) {
      const postData = config as WeaveConfig
      if (!errorMessage && !postData.api_key)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' })
      if (!errorMessage && !postData.project)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.project`) })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })
    }

    if (type === TracingProvider.aliyun) {
      const postData = config as AliyunConfig
      if (!errorMessage && !postData.app_name)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'App Name' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'App Name' })
      if (!errorMessage && !postData.license_key)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'License Key' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'License Key' })
      if (!errorMessage && !postData.endpoint)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Endpoint' })
    }

    if (type === TracingProvider.mlflow) {
      const postData = config as MLflowConfig
      if (!errorMessage && !postData.tracking_uri)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'Tracking URI' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Tracking URI' })
    }

    if (type === TracingProvider.databricks) {
      const postData = config as DatabricksConfig
      if (!errorMessage && !postData.experiment_id)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'Experiment ID' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Experiment ID' })
      if (!errorMessage && !postData.host)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'Host' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Host' })
    }

    if (type === TracingProvider.tencent) {
      const postData = config as TencentConfig
      if (!errorMessage && !postData.token)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'Token' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Token' })
      if (!errorMessage && !postData.endpoint)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Endpoint' })
      if (!errorMessage && !postData.service_name)
-       errorMessage = t('common.errorMsg.fieldRequired', { field: 'Service Name' })
+       errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Service Name' })
    }

    return errorMessage
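Every change in this validation helper applies the same migration: the namespace moves out of the translation key and into i18next's `ns` option. A minimal sketch of the two call styles, assuming a standard i18next setup — the resource shapes and strings below are illustrative, not Dify's real bundles:

```typescript
import i18next from 'i18next'

// Illustrative resources only; the key names mirror the diff, the strings are assumed.
await i18next.init({
  lng: 'en',
  resources: {
    en: {
      // Old layout: 'common.*' keys nested inside the default 'translation' bundle.
      translation: { common: { errorMsg: { fieldRequired: '{{field}} is required' } } },
      // New layout: 'common' is its own namespace.
      common: { errorMsg: { fieldRequired: '{{field}} is required' } },
    },
  },
})

// Before: the namespace is baked into the key and resolved in the default bundle.
i18next.t('common.errorMsg.fieldRequired', { field: 'API Key' }) // => 'API Key is required'

// After: the key is namespace-relative and the bundle is selected via `ns`.
i18next.t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' }) // => 'API Key is required'
```

Both calls produce the same string here; the payoff of the second style is that each namespace can live in its own separately loaded file without rewriting every key.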
@ -281,7 +281,7 @@ const ProviderConfigModal: FC<Props> = ({
      })
      Toast.notify({
        type: 'success',
-       message: t('common.api.success'),
+       message: t('api.success', { ns: 'common' }),
      })
      onSaved(config)
      if (isAdd)

@ -303,8 +303,8 @@ const ProviderConfigModal: FC<Props> = ({
        <div className="px-8 pt-8">
          <div className="mb-4 flex items-center justify-between">
            <div className="title-2xl-semi-bold text-text-primary">
-             {t(`${I18N_PREFIX}.title`)}
-             {t(`app.tracing.${type}.title`)}
+             {t(`${I18N_PREFIX}.title`, { ns: 'app' })}
+             {t(`tracing.${type}.title`, { ns: 'app' })}
            </div>
          </div>

@ -317,7 +317,7 @@ const ProviderConfigModal: FC<Props> = ({
            isRequired
            value={(config as ArizeConfig).api_key}
            onChange={handleConfigChange('api_key')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!}
          />
          <Field
            label="Space ID"

@ -325,15 +325,15 @@ const ProviderConfigModal: FC<Props> = ({
            isRequired
            value={(config as ArizeConfig).space_id}
            onChange={handleConfigChange('space_id')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'Space ID' })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'Space ID' })!}
          />
          <Field
-           label={t(`${I18N_PREFIX}.project`)!}
+           label={t(`${I18N_PREFIX}.project`, { ns: 'app' })!}
            labelClassName="!text-sm"
            isRequired
            value={(config as ArizeConfig).project}
            onChange={handleConfigChange('project')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.project`) })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })!}
          />
          <Field
            label="Endpoint"
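The placeholder props above show the same migration on nested `t()` calls: the inner call that produces the interpolated `key` value needs its own `{ ns: 'app' }`, since it runs first and hands a plain string to the outer call. A small sketch of that evaluation order, continuing the illustrative setup from the previous snippet:

```typescript
import i18next from 'i18next'

// Hypothetical 'app' namespace entries shaped like the keys used in this diff.
i18next.addResourceBundle('en', 'app', {
  tracing: {
    placeholder: 'Enter your {{key}}',
    project: 'Project',
  },
})

const I18N_PREFIX = 'tracing'

// The inner call resolves 'Project' in the 'app' namespace first; the outer
// call then interpolates that string into the placeholder template.
const placeholder = i18next.t(`${I18N_PREFIX}.placeholder`, {
  ns: 'app',
  key: i18next.t(`${I18N_PREFIX}.project`, { ns: 'app' }),
}) // => 'Enter your Project'
```

Forgetting `ns` on the inner call would make it fall back to the default namespace and render the raw key instead of the label.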
@ -352,15 +352,15 @@ const ProviderConfigModal: FC<Props> = ({
            isRequired
            value={(config as PhoenixConfig).api_key}
            onChange={handleConfigChange('api_key')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!}
          />
          <Field
-           label={t(`${I18N_PREFIX}.project`)!}
+           label={t(`${I18N_PREFIX}.project`, { ns: 'app' })!}
            labelClassName="!text-sm"
            isRequired
            value={(config as PhoenixConfig).project}
            onChange={handleConfigChange('project')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.project`) })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })!}
          />
          <Field
            label="Endpoint"

@ -379,7 +379,7 @@ const ProviderConfigModal: FC<Props> = ({
            isRequired
            value={(config as AliyunConfig).license_key}
            onChange={handleConfigChange('license_key')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'License Key' })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'License Key' })!}
          />
          <Field
            label="Endpoint"

@ -404,7 +404,7 @@ const ProviderConfigModal: FC<Props> = ({
            isRequired
            value={(config as TencentConfig).token}
            onChange={handleConfigChange('token')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'Token' })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'Token' })!}
          />
          <Field
            label="Endpoint"

@ -432,22 +432,22 @@ const ProviderConfigModal: FC<Props> = ({
            isRequired
            value={(config as WeaveConfig).api_key}
            onChange={handleConfigChange('api_key')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!}
          />
          <Field
-           label={t(`${I18N_PREFIX}.project`)!}
+           label={t(`${I18N_PREFIX}.project`, { ns: 'app' })!}
            labelClassName="!text-sm"
            isRequired
            value={(config as WeaveConfig).project}
            onChange={handleConfigChange('project')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.project`) })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })!}
          />
          <Field
            label="Entity"
            labelClassName="!text-sm"
            value={(config as WeaveConfig).entity}
            onChange={handleConfigChange('entity')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'Entity' })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'Entity' })!}
          />
          <Field
            label="Endpoint"

@ -473,15 +473,15 @@ const ProviderConfigModal: FC<Props> = ({
            isRequired
            value={(config as LangSmithConfig).api_key}
            onChange={handleConfigChange('api_key')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!}
          />
          <Field
-           label={t(`${I18N_PREFIX}.project`)!}
+           label={t(`${I18N_PREFIX}.project`, { ns: 'app' })!}
            labelClassName="!text-sm"
            isRequired
            value={(config as LangSmithConfig).project}
            onChange={handleConfigChange('project')}
-           placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.project`) })!}
+           placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })!}
          />
          <Field
            label="Endpoint"

@ -495,20 +495,20 @@ const ProviderConfigModal: FC<Props> = ({
        {type === TracingProvider.langfuse && (
          <>
            <Field
-             label={t(`${I18N_PREFIX}.secretKey`)!}
+             label={t(`${I18N_PREFIX}.secretKey`, { ns: 'app' })!}
              labelClassName="!text-sm"
              value={(config as LangFuseConfig).secret_key}
              isRequired
              onChange={handleConfigChange('secret_key')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.secretKey`) })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.secretKey`, { ns: 'app' }) })!}
            />
            <Field
-             label={t(`${I18N_PREFIX}.publicKey`)!}
+             label={t(`${I18N_PREFIX}.publicKey`, { ns: 'app' })!}
              labelClassName="!text-sm"
              isRequired
              value={(config as LangFuseConfig).public_key}
              onChange={handleConfigChange('public_key')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.publicKey`) })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.publicKey`, { ns: 'app' }) })!}
            />
            <Field
              label="Host"

@ -527,14 +527,14 @@ const ProviderConfigModal: FC<Props> = ({
              labelClassName="!text-sm"
              value={(config as OpikConfig).api_key}
              onChange={handleConfigChange('api_key')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!}
            />
            <Field
-             label={t(`${I18N_PREFIX}.project`)!}
+             label={t(`${I18N_PREFIX}.project`, { ns: 'app' })!}
              labelClassName="!text-sm"
              value={(config as OpikConfig).project}
              onChange={handleConfigChange('project')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.project`) })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })!}
            />
            <Field
              label="Workspace"

@ -555,7 +555,7 @@ const ProviderConfigModal: FC<Props> = ({
        {type === TracingProvider.mlflow && (
          <>
            <Field
-             label={t(`${I18N_PREFIX}.trackingUri`)!}
+             label={t(`${I18N_PREFIX}.trackingUri`, { ns: 'app' })!}
              labelClassName="!text-sm"
              value={(config as MLflowConfig).tracking_uri}
              isRequired

@ -563,67 +563,67 @@ const ProviderConfigModal: FC<Props> = ({
              placeholder="http://localhost:5000"
            />
            <Field
-             label={t(`${I18N_PREFIX}.experimentId`)!}
+             label={t(`${I18N_PREFIX}.experimentId`, { ns: 'app' })!}
              labelClassName="!text-sm"
              isRequired
              value={(config as MLflowConfig).experiment_id}
              onChange={handleConfigChange('experiment_id')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.experimentId`) })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.experimentId`, { ns: 'app' }) })!}
            />
            <Field
-             label={t(`${I18N_PREFIX}.username`)!}
+             label={t(`${I18N_PREFIX}.username`, { ns: 'app' })!}
              labelClassName="!text-sm"
              value={(config as MLflowConfig).username}
              onChange={handleConfigChange('username')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.username`) })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.username`, { ns: 'app' }) })!}
            />
            <Field
-             label={t(`${I18N_PREFIX}.password`)!}
+             label={t(`${I18N_PREFIX}.password`, { ns: 'app' })!}
              labelClassName="!text-sm"
              value={(config as MLflowConfig).password}
              onChange={handleConfigChange('password')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.password`) })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.password`, { ns: 'app' }) })!}
            />
          </>
        )}
        {type === TracingProvider.databricks && (
          <>
            <Field
-             label={t(`${I18N_PREFIX}.experimentId`)!}
+             label={t(`${I18N_PREFIX}.experimentId`, { ns: 'app' })!}
              labelClassName="!text-sm"
              value={(config as DatabricksConfig).experiment_id}
              onChange={handleConfigChange('experiment_id')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.experimentId`) })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.experimentId`, { ns: 'app' }) })!}
              isRequired
            />
            <Field
-             label={t(`${I18N_PREFIX}.databricksHost`)!}
+             label={t(`${I18N_PREFIX}.databricksHost`, { ns: 'app' })!}
              labelClassName="!text-sm"
              value={(config as DatabricksConfig).host}
              onChange={handleConfigChange('host')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.databricksHost`) })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.databricksHost`, { ns: 'app' }) })!}
              isRequired
            />
            <Field
-             label={t(`${I18N_PREFIX}.clientId`)!}
+             label={t(`${I18N_PREFIX}.clientId`, { ns: 'app' })!}
              labelClassName="!text-sm"
              value={(config as DatabricksConfig).client_id}
              onChange={handleConfigChange('client_id')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.clientId`) })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.clientId`, { ns: 'app' }) })!}
            />
            <Field
-             label={t(`${I18N_PREFIX}.clientSecret`)!}
+             label={t(`${I18N_PREFIX}.clientSecret`, { ns: 'app' })!}
              labelClassName="!text-sm"
              value={(config as DatabricksConfig).client_secret}
              onChange={handleConfigChange('client_secret')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.clientSecret`) })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.clientSecret`, { ns: 'app' }) })!}
            />
            <Field
-             label={t(`${I18N_PREFIX}.personalAccessToken`)!}
+             label={t(`${I18N_PREFIX}.personalAccessToken`, { ns: 'app' })!}
              labelClassName="!text-sm"
              value={(config as DatabricksConfig).personal_access_token}
              onChange={handleConfigChange('personal_access_token')}
-             placeholder={t(`${I18N_PREFIX}.placeholder`, { key: t(`${I18N_PREFIX}.personalAccessToken`) })!}
+             placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.personalAccessToken`, { ns: 'app' }) })!}
            />
          </>
        )}

@ -634,7 +634,7 @@ const ProviderConfigModal: FC<Props> = ({
            target="_blank"
            href={docURL[type]}
          >
-           <span>{t(`${I18N_PREFIX}.viewDocsLink`, { key: t(`app.tracing.${type}.title`) })}</span>
+           <span>{t(`${I18N_PREFIX}.viewDocsLink`, { ns: 'app', key: t(`tracing.${type}.title`, { ns: 'app' }) })}</span>
            <LinkExternal02 className="h-3 w-3" />
          </a>
          <div className="flex items-center">

@ -644,7 +644,7 @@ const ProviderConfigModal: FC<Props> = ({
              className="h-9 text-sm font-medium text-text-secondary"
              onClick={showRemoveConfirm}
            >
-             <span className="text-[#D92D20]">{t('common.operation.remove')}</span>
+             <span className="text-[#D92D20]">{t('operation.remove', { ns: 'common' })}</span>
            </Button>
            <Divider type="vertical" className="mx-3 h-[18px]" />
          </>

@ -653,7 +653,7 @@ const ProviderConfigModal: FC<Props> = ({
            className="mr-2 h-9 text-sm font-medium text-text-secondary"
            onClick={onCancel}
          >
-           {t('common.operation.cancel')}
+           {t('operation.cancel', { ns: 'common' })}
          </Button>
          <Button
            className="h-9 text-sm font-medium"

@ -661,7 +661,7 @@ const ProviderConfigModal: FC<Props> = ({
            onClick={handleSave}
            loading={isSaving}
          >
-           {t(`common.operation.${isAdd ? 'saveAndEnable' : 'save'}`)}
+           {t(`operation.${isAdd ? 'saveAndEnable' : 'save'}`, { ns: 'common' })}
          </Button>
        </div>

@ -670,7 +670,7 @@ const ProviderConfigModal: FC<Props> = ({
      <div className="border-t-[0.5px] border-divider-regular">
        <div className="flex items-center justify-center bg-background-section-burn py-3 text-xs text-text-tertiary">
          <Lock01 className="mr-1 h-3 w-3 text-text-tertiary" />
-         {t('common.modelProvider.encrypted.front')}
+         {t('modelProvider.encrypted.front', { ns: 'common' })}
          <a
            className="mx-1 text-primary-600"
            target="_blank"

@ -679,7 +679,7 @@ const ProviderConfigModal: FC<Props> = ({
          >
            PKCS1_OAEP
          </a>
-         {t('common.modelProvider.encrypted.back')}
+         {t('modelProvider.encrypted.back', { ns: 'common' })}
        </div>
      </div>
    </div>

@ -691,8 +691,8 @@ const ProviderConfigModal: FC<Props> = ({
      <Confirm
        isShow
        type="warning"
-       title={t(`${I18N_PREFIX}.removeConfirmTitle`, { key: t(`app.tracing.${type}.title`) })!}
-       content={t(`${I18N_PREFIX}.removeConfirmContent`)}
+       title={t(`${I18N_PREFIX}.removeConfirmTitle`, { ns: 'app', key: t(`tracing.${type}.title`, { ns: 'app' }) })!}
+       content={t(`${I18N_PREFIX}.removeConfirmContent`, { ns: 'app' })}
        onConfirm={handleRemove}
        onCancel={hideRemoveConfirm}
      />
@ -11,7 +11,7 @@ import { Eye as View } from '@/app/components/base/icons/src/vender/solid/genera
import { cn } from '@/utils/classnames'
import { TracingProvider } from './type'

-const I18N_PREFIX = 'app.tracing'
+const I18N_PREFIX = 'tracing'

type Props = {
  type: TracingProvider

@ -82,14 +82,14 @@ const ProviderPanel: FC<Props> = ({
      <div className="flex items-center justify-between space-x-1">
        <div className="flex items-center">
          <Icon className="h-6" />
-         {isChosen && <div className="system-2xs-medium-uppercase ml-1 flex h-4 items-center rounded-[4px] border border-text-accent-secondary px-1 text-text-accent-secondary">{t(`${I18N_PREFIX}.inUse`)}</div>}
+         {isChosen && <div className="system-2xs-medium-uppercase ml-1 flex h-4 items-center rounded-[4px] border border-text-accent-secondary px-1 text-text-accent-secondary">{t(`${I18N_PREFIX}.inUse`, { ns: 'app' })}</div>}
        </div>
        {!readOnly && (
          <div className="flex items-center justify-between space-x-1">
            {hasConfigured && (
              <div className="flex h-6 cursor-pointer items-center space-x-1 rounded-md border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-2 text-text-secondary shadow-xs" onClick={viewBtnClick}>
                <View className="h-3 w-3" />
-               <div className="text-xs font-medium">{t(`${I18N_PREFIX}.view`)}</div>
+               <div className="text-xs font-medium">{t(`${I18N_PREFIX}.view`, { ns: 'app' })}</div>
              </div>
            )}
            <div

@ -97,13 +97,13 @@ const ProviderPanel: FC<Props> = ({
              onClick={handleConfigBtnClick}
            >
              <RiEqualizer2Line className="h-3 w-3" />
-             <div className="text-xs font-medium">{t(`${I18N_PREFIX}.config`)}</div>
+             <div className="text-xs font-medium">{t(`${I18N_PREFIX}.config`, { ns: 'app' })}</div>
            </div>
          </div>
        )}
      </div>
      <div className="system-xs-regular mt-2 text-text-tertiary">
-       {t(`${I18N_PREFIX}.${type}.description`)}
+       {t(`${I18N_PREFIX}.${type}.description`, { ns: 'app' })}
      </div>
    </div>
  )
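The one-line `I18N_PREFIX` change is what keeps the call sites in these two files consistent: once keys are namespace-relative, the `app.` segment belongs in `{ ns: 'app' }` rather than in the prefix string. A hedged before/after sketch with made-up resources (the real keys live under `web/i18n` in the repo):

```typescript
import i18next from 'i18next'

// Assumed 'app' namespace shape, for illustration only.
await i18next.init({
  lng: 'en',
  resources: { en: { app: { tracing: { view: 'View', config: 'Config', inUse: 'In use' } } } },
})

// Before: const I18N_PREFIX = 'app.tracing' with t(`${I18N_PREFIX}.view`)
// only resolved while those keys sat inside the default namespace.

// After: the prefix drops the namespace segment and every call names the bundle.
const I18N_PREFIX = 'tracing'
i18next.t(`${I18N_PREFIX}.view`, { ns: 'app' }) // => 'View'
```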
Some files were not shown because too many files have changed in this diff.