diff --git a/.claude/skills/frontend-testing/assets/component-test.template.tsx b/.claude/skills/frontend-testing/assets/component-test.template.tsx index c39baff916..6b7803bd4b 100644 --- a/.claude/skills/frontend-testing/assets/component-test.template.tsx +++ b/.claude/skills/frontend-testing/assets/component-test.template.tsx @@ -28,17 +28,14 @@ import userEvent from '@testing-library/user-event' // i18n (automatically mocked) // WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup -// No explicit mock needed - it returns translation keys as-is +// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage +// No explicit mock needed for most tests +// // Override only if custom translations are required: -// vi.mock('react-i18next', () => ({ -// useTranslation: () => ({ -// t: (key: string) => { -// const customTranslations: Record = { -// 'my.custom.key': 'Custom Translation', -// } -// return customTranslations[key] || key -// }, -// }), +// import { createReactI18nextMock } from '@/test/i18n-mock' +// vi.mock('react-i18next', () => createReactI18nextMock({ +// 'my.custom.key': 'Custom Translation', +// 'button.save': 'Save', // })) // Router (if component uses useRouter, usePathname, useSearchParams) diff --git a/.claude/skills/frontend-testing/references/mocking.md b/.claude/skills/frontend-testing/references/mocking.md index 23889c8d3d..c70bcf0ae5 100644 --- a/.claude/skills/frontend-testing/references/mocking.md +++ b/.claude/skills/frontend-testing/references/mocking.md @@ -52,23 +52,29 @@ Modules are not mocked automatically. Use `vi.mock` in test files, or add global ### 1. i18n (Auto-loaded via Global Mock) A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup. -**No explicit mock needed** for most tests - it returns translation keys as-is. -For tests requiring custom translations, override the mock: +The global mock provides: + +- `useTranslation` - returns translation keys with namespace prefix +- `Trans` component - renders i18nKey and components +- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`) +- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'` + +**Default behavior**: Most tests should use the global mock (no local override needed). + +**For custom translations**: Use the helper function from `@/test/i18n-mock`: ```typescript -vi.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => { - const translations: Record = { - 'my.custom.key': 'Custom translation', - } - return translations[key] || key - }, - }), +import { createReactI18nextMock } from '@/test/i18n-mock' + +vi.mock('react-i18next', () => createReactI18nextMock({ + 'my.custom.key': 'Custom translation', + 'button.save': 'Save', })) ``` +**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this. + ### 2. 
Next.js Router ```typescript diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index d463349686..462ece303e 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -110,6 +110,16 @@ jobs: working-directory: ./web run: pnpm run type-check:tsgo + - name: Web dead code check + if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web + run: pnpm run knip + + - name: Web build check + if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web + run: pnpm run build + superlinter: name: SuperLinter runs-on: ubuntu-latest diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index 58cd18371f..a51350f630 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -4,7 +4,7 @@ on: push: branches: [main] paths: - - 'web/i18n/en-US/*.ts' + - 'web/i18n/en-US/*.json' permissions: contents: write @@ -28,13 +28,13 @@ jobs: run: | git fetch origin "${{ github.event.before }}" || true git fetch origin "${{ github.sha }}" || true - changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts') + changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json') echo "Changed files: $changed_files" if [ -n "$changed_files" ]; then echo "FILES_CHANGED=true" >> $GITHUB_ENV file_args="" for file in $changed_files; do - filename=$(basename "$file" .ts) + filename=$(basename "$file" .json) file_args="$file_args --file $filename" done echo "FILE_ARGS=$file_args" >> $GITHUB_ENV @@ -65,7 +65,7 @@ jobs: - name: Generate i18n translations if: env.FILES_CHANGED == 'true' working-directory: ./web - run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }} + run: pnpm run i18n:gen ${{ env.FILE_ARGS }} - name: Create Pull Request if: env.FILES_CHANGED == 'true' diff --git a/api/.env.example b/api/.env.example index 9cbb111d31..5f8d369ec4 100644 --- a/api/.env.example +++ b/api/.env.example @@ -101,6 +101,15 @@ S3_ACCESS_KEY=your-access-key S3_SECRET_KEY=your-secret-key S3_REGION=your-region +# Workflow run and Conversation archive storage (S3-compatible) +ARCHIVE_STORAGE_ENABLED=false +ARCHIVE_STORAGE_ENDPOINT= +ARCHIVE_STORAGE_ARCHIVE_BUCKET= +ARCHIVE_STORAGE_EXPORT_BUCKET= +ARCHIVE_STORAGE_ACCESS_KEY= +ARCHIVE_STORAGE_SECRET_KEY= +ARCHIVE_STORAGE_REGION=auto + # Azure Blob Storage configuration AZURE_BLOB_ACCOUNT_NAME=your-account-name AZURE_BLOB_ACCOUNT_KEY=your-account-key @@ -128,6 +137,7 @@ TENCENT_COS_SECRET_KEY=your-secret-key TENCENT_COS_SECRET_ID=your-secret-id TENCENT_COS_REGION=your-region TENCENT_COS_SCHEME=your-scheme +TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain # Huawei OBS Storage Configuration HUAWEI_OBS_BUCKET_NAME=your-bucket-name diff --git a/api/.ruff.toml b/api/.ruff.toml index 7206f7fa0f..8db0cbcb21 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -1,4 +1,8 @@ -exclude = ["migrations/*"] +exclude = [ + "migrations/*", + ".git", + ".git/**", +] line-length = 120 [format] diff --git a/api/configs/extra/__init__.py b/api/configs/extra/__init__.py index 4543b5389d..de97adfc0e 100644 --- a/api/configs/extra/__init__.py +++ b/api/configs/extra/__init__.py @@ -1,9 +1,11 @@ +from configs.extra.archive_config import ArchiveStorageConfig from configs.extra.notion_config import NotionConfig from configs.extra.sentry_config import SentryConfig class ExtraServiceConfig( # place the configs in alphabet order + 
ArchiveStorageConfig, NotionConfig, SentryConfig, ): diff --git a/api/configs/extra/archive_config.py b/api/configs/extra/archive_config.py new file mode 100644 index 0000000000..a85628fa61 --- /dev/null +++ b/api/configs/extra/archive_config.py @@ -0,0 +1,43 @@ +from pydantic import Field +from pydantic_settings import BaseSettings + + +class ArchiveStorageConfig(BaseSettings): + """ + Configuration settings for workflow run logs archiving storage. + """ + + ARCHIVE_STORAGE_ENABLED: bool = Field( + description="Enable workflow run logs archiving to S3-compatible storage", + default=False, + ) + + ARCHIVE_STORAGE_ENDPOINT: str | None = Field( + description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')", + default=None, + ) + + ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field( + description="Name of the bucket to store archived workflow logs", + default=None, + ) + + ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field( + description="Name of the bucket to store exported workflow runs", + default=None, + ) + + ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field( + description="Access key ID for authenticating with storage", + default=None, + ) + + ARCHIVE_STORAGE_SECRET_KEY: str | None = Field( + description="Secret access key for authenticating with storage", + default=None, + ) + + ARCHIVE_STORAGE_REGION: str = Field( + description="Region for storage (use 'auto' if the provider supports it)", + default="auto", + ) diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py index e297e748e9..cdd10740f8 100644 --- a/api/configs/middleware/storage/tencent_cos_storage_config.py +++ b/api/configs/middleware/storage/tencent_cos_storage_config.py @@ -31,3 +31,8 @@ class TencentCloudCOSStorageConfig(BaseSettings): description="Protocol scheme for COS requests: 'https' (recommended) or 'http'", default=None, ) + + TENCENT_COS_CUSTOM_DOMAIN: str | None = Field( + description="Tencent Cloud COS custom domain setting", + default=None, + ) diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index df9de825de..c16a23fac8 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -1,62 +1,59 @@ -from flask_restx import Api, Namespace, fields +from __future__ import annotations -from libs.helper import AppIconUrlField +from typing import Any, TypeAlias -parameters__system_parameters = { - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "file_size_limit": fields.Integer, - "workflow_file_upload_limit": fields.Integer, -} +from pydantic import BaseModel, ConfigDict, computed_field + +from core.file import helpers as file_helpers +from models.model import IconType + +JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] +JSONObject: TypeAlias = dict[str, Any] -def build_system_parameters_model(api_or_ns: Api | Namespace): - """Build the system parameters model for the API or Namespace.""" - return api_or_ns.model("SystemParameters", parameters__system_parameters) +class SystemParameters(BaseModel): + image_file_size_limit: int + video_file_size_limit: int + audio_file_size_limit: int + file_size_limit: int + workflow_file_upload_limit: int -parameters_fields = { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "suggested_questions_after_answer": fields.Raw, - "speech_to_text": fields.Raw, - "text_to_speech": 
fields.Raw, - "retriever_resource": fields.Raw, - "annotation_reply": fields.Raw, - "more_like_this": fields.Raw, - "user_input_form": fields.Raw, - "sensitive_word_avoidance": fields.Raw, - "file_upload": fields.Raw, - "system_parameters": fields.Nested(parameters__system_parameters), -} +class Parameters(BaseModel): + opening_statement: str | None = None + suggested_questions: list[str] + suggested_questions_after_answer: JSONObject + speech_to_text: JSONObject + text_to_speech: JSONObject + retriever_resource: JSONObject + annotation_reply: JSONObject + more_like_this: JSONObject + user_input_form: list[JSONObject] + sensitive_word_avoidance: JSONObject + file_upload: JSONObject + system_parameters: SystemParameters -def build_parameters_model(api_or_ns: Api | Namespace): - """Build the parameters model for the API or Namespace.""" - copied_fields = parameters_fields.copy() - copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns)) - return api_or_ns.model("Parameters", copied_fields) +class Site(BaseModel): + model_config = ConfigDict(from_attributes=True) + title: str + chat_color_theme: str | None = None + chat_color_theme_inverted: bool + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + description: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + default_language: str + show_workflow_steps: bool + use_icon_as_answer_icon: bool -site_fields = { - "title": fields.String, - "chat_color_theme": fields.String, - "chat_color_theme_inverted": fields.Boolean, - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "description": fields.String, - "copyright": fields.String, - "privacy_policy": fields.String, - "custom_disclaimer": fields.String, - "default_language": fields.String, - "show_workflow_steps": fields.Boolean, - "use_icon_as_answer_icon": fields.Boolean, -} - - -def build_site_model(api_or_ns: Api | Namespace): - """Build the site model for the API or Namespace.""" - return api_or_ns.model("Site", site_fields) + @computed_field(return_type=str | None) # type: ignore + @property + def icon_url(self) -> str | None: + if self.icon and self.icon_type == IconType.IMAGE: + return file_helpers.get_signed_file_url(self.icon) + return None diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 62e997dae2..44cf89d6a9 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,3 +1,4 @@ +import re import uuid from typing import Literal @@ -73,6 +74,48 @@ class AppListQuery(BaseModel): raise ValueError("Invalid UUID format in tag_ids.") from exc +# XSS prevention: patterns that could lead to XSS attacks +# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc. +_XSS_PATTERNS = [ + r"]*>.*?", # Script tags + r"]*?(?:/>|>.*?)", # Iframe tags (including self-closing) + r"javascript:", # JavaScript protocol + r"]*?\s+onload\s*=[^>]*>", # SVG with onload handler (attribute-aware, flexible whitespace) + r"<.*?on\s*\w+\s*=", # Event handlers like onclick, onerror, etc. + r"]*(?:\s*/>|>.*?)", # Object tags (opening tag) + r"]*>", # Embed tags (self-closing) + r"]*>", # Link tags with javascript +] + + +def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None: + """ + Validate that a string value doesn't contain potential XSS payloads. 
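+ Matching is case-insensitive and uses re.DOTALL, so payloads split across lines are still caught.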
+ + Args: + value: The string value to validate + field_name: Name of the field for error messages + + Returns: + The original value if safe + + Raises: + ValueError: If the value contains XSS patterns + """ + if value is None: + return None + + value_lower = value.lower() + for pattern in _XSS_PATTERNS: + if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE): + raise ValueError( + f"{field_name} contains invalid characters or patterns. " + "HTML tags, JavaScript, and other potentially dangerous content are not allowed." + ) + + return value + + class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) @@ -81,6 +124,11 @@ class CreateAppPayload(BaseModel): icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") + @field_validator("name", "description", mode="before") + @classmethod + def validate_xss_safe(cls, value: str | None, info) -> str | None: + return _validate_xss_safe(value, info.field_name) + class UpdateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") @@ -91,6 +139,11 @@ class UpdateAppPayload(BaseModel): use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") max_active_requests: int | None = Field(default=None, description="Maximum active requests") + @field_validator("name", "description", mode="before") + @classmethod + def validate_xss_safe(cls, value: str | None, info) -> str | None: + return _validate_xss_safe(value, info.field_name) + class CopyAppPayload(BaseModel): name: str | None = Field(default=None, description="Name for the copied app") @@ -99,6 +152,11 @@ class CopyAppPayload(BaseModel): icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") + @field_validator("name", "description", mode="before") + @classmethod + def validate_xss_safe(cls, value: str | None, info) -> str | None: + return _validate_xss_safe(value, info.field_name) + class AppExportQuery(BaseModel): include_secret: bool = Field(default=False, description="Include secrets in export") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 7ad1e56373..c20e83d36f 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -124,7 +124,7 @@ class OAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") try: - account = _generate_account(provider, user_info) + account, oauth_new_user = _generate_account(provider, user_info) except AccountNotFoundError: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.") except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError): @@ -159,7 +159,10 @@ class OAuthCallback(Resource): ip_address=extract_remote_ip(request), ) - response = redirect(f"{dify_config.CONSOLE_WEB_URL}") + base_url = dify_config.CONSOLE_WEB_URL + query_char = "&" if "?" in base_url else "?" 
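+ # Append oauth_new_user to the redirect URL so the console web app can tell whether this sign-in just created the account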
+ target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}" + response = redirect(target_url) set_access_token_to_cookie(request, response, token_pair.access_token) set_refresh_token_to_cookie(request, response, token_pair.refresh_token) @@ -177,9 +180,10 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> return account -def _generate_account(provider: str, user_info: OAuthUserInfo): +def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]: # Get account by openid or email. account = _get_account_by_openid_or_email(provider, user_info) + oauth_new_user = False if account: tenants = TenantService.get_join_tenants(account) @@ -193,6 +197,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): tenant_was_created.send(new_tenant) if not account: + oauth_new_user = True if not FeatureService.get_system_features().is_allow_register: if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email): raise AccountRegisterError( @@ -220,4 +225,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): # Link account AccountService.link_account_integrate(provider, user_info.id, account) - return account + return account, oauth_new_user diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index e73abc2555..5a536af6d2 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -3,10 +3,12 @@ import uuid from flask import request from flask_restx import Resource, marshal from pydantic import BaseModel, Field -from sqlalchemy import select +from sqlalchemy import String, cast, func, or_, select +from sqlalchemy.dialects.postgresql import JSONB from werkzeug.exceptions import Forbidden, NotFound import services +from configs import dify_config from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ProviderNotInitializeError @@ -143,7 +145,29 @@ class DatasetDocumentSegmentListApi(Resource): query = query.where(DocumentSegment.hit_count >= hit_count_gte) if keyword: - query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) + # Search in both content and keywords fields + # Use database-specific methods for JSON array search + if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql": + # PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text + keywords_condition = func.array_to_string( + func.array( + select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB))) + .correlate(DocumentSegment) + .scalar_subquery() + ), + ",", + ).ilike(f"%{keyword}%") + else: + # MySQL: Cast JSON to string for pattern matching + # MySQL stores Chinese text directly in JSON without Unicode escaping + keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%") + + query = query.where( + or_( + DocumentSegment.content.ilike(f"%{keyword}%"), + keywords_condition, + ) + ) if args.enabled.lower() != "all": if args.enabled.lower() == "true": diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 9c6b2aedfb..660a4d5aea 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,5 +1,3 @@ -from flask_restx import marshal_with - from controllers.common import fields from controllers.console 
import console_ns from controllers.console.app.error import AppUnavailableError @@ -13,7 +11,6 @@ from services.app_service import AppService class AppParameterApi(InstalledAppResource): """Resource for app variables.""" - @marshal_with(fields.parameters_fields) def get(self, installed_app: InstalledApp): """Retrieve app parameters.""" app_model = installed_app.app @@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource): user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + return fields.Parameters.model_validate(parameters).model_dump(mode="json") @console_ns.route("/installed-apps//meta", endpoint="installed_app_meta") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index d51b37a9cd..e9e7b72718 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -20,7 +20,6 @@ from controllers.console.wraps import ( ) from core.db.session_factory import session_factory from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration -from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.mcp_client import MCPClient @@ -987,9 +986,6 @@ class ToolProviderMCPApi(Resource): # Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is logger.warning("Failed to fetch MCP tools after creation", exc_info=True) - # Final cache invalidation to ensure list views are up to date - ToolProviderListCache.invalidate_cache(tenant_id) - return jsonable_encoder(result) @console_ns.expect(parser_mcp_put) @@ -1036,9 +1032,6 @@ class ToolProviderMCPApi(Resource): validation_result=validation_result, ) - # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations - ToolProviderListCache.invalidate_cache(current_tenant_id) - return {"result": "success"} @console_ns.expect(parser_mcp_delete) @@ -1053,9 +1046,6 @@ class ToolProviderMCPApi(Resource): service = MCPToolManageService(session=session) service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) - # Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations - ToolProviderListCache.invalidate_cache(current_tenant_id) - return {"result": "success"} @@ -1106,8 +1096,6 @@ class ToolMCPAuthApi(Resource): credentials=provider_entity.credentials, authed=True, ) - # Invalidate cache after updating credentials - ToolProviderListCache.invalidate_cache(tenant_id) return {"result": "success"} except MCPAuthError as e: try: @@ -1121,22 +1109,16 @@ class ToolMCPAuthApi(Resource): with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) response = service.execute_auth_actions(auth_result) - # Invalidate cache after auth actions may have updated provider state - ToolProviderListCache.invalidate_cache(tenant_id) return response except MCPRefreshTokenError as e: with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) - # Invalidate cache after clearing credentials - 
ToolProviderListCache.invalidate_cache(tenant_id) raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e except (MCPError, ValueError) as e: with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) - # Invalidate cache after clearing credentials - ToolProviderListCache.invalidate_cache(tenant_id) raise ValueError(f"Failed to connect to MCP server: {e}") from e diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 63c373b50f..85ac9336d6 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,7 +1,7 @@ from typing import Literal from flask import request -from flask_restx import Api, Namespace, Resource, fields +from flask_restx import Namespace, Resource, fields from flask_restx.api import HTTPStatus from pydantic import BaseModel, Field @@ -92,7 +92,7 @@ annotation_list_fields = { } -def build_annotation_list_model(api_or_ns: Api | Namespace): +def build_annotation_list_model(api_or_ns: Namespace): """Build the annotation list model for the API or Namespace.""" copied_annotation_list_fields = annotation_list_fields.copy() copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 25d7ccccec..562f5e33cc 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,6 +1,6 @@ from flask_restx import Resource -from controllers.common.fields import build_parameters_model +from controllers.common.fields import Parameters from controllers.service_api import service_api_ns from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token @@ -23,7 +23,6 @@ class AppParameterApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_parameters_model(service_api_ns)) def get(self, app_model: App): """Retrieve app parameters. @@ -45,7 +44,8 @@ class AppParameterApi(Resource): user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + return Parameters.model_validate(parameters).model_dump(mode="json") @service_api_ns.route("/meta") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index 9f8324a84e..8b47a887bb 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -1,7 +1,7 @@ from flask_restx import Resource from werkzeug.exceptions import Forbidden -from controllers.common.fields import build_site_model +from controllers.common.fields import Site as SiteResponse from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_database import db @@ -23,7 +23,6 @@ class AppSiteApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_site_model(service_api_ns)) def get(self, app_model: App): """Retrieve app site info. 
@@ -38,4 +37,4 @@ class AppSiteApi(Resource): if app_model.tenant.status == TenantStatus.ARCHIVE: raise Forbidden() - return site + return SiteResponse.model_validate(site).model_dump(mode="json") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 4964888fd6..6a549fc926 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -3,7 +3,7 @@ from typing import Any, Literal from dateutil.parser import isoparse from flask import request -from flask_restx import Api, Namespace, Resource, fields +from flask_restx import Namespace, Resource, fields from pydantic import BaseModel, Field from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -78,7 +78,7 @@ workflow_run_fields = { } -def build_workflow_run_model(api_or_ns: Api | Namespace): +def build_workflow_run_model(api_or_ns: Namespace): """Build the workflow run model for the API or Namespace.""" return api_or_ns.model("WorkflowRun", workflow_run_fields) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index db3b93a4dc..62ea532eac 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restx import Resource, marshal_with +from flask_restx import Resource from pydantic import BaseModel, ConfigDict, Field from werkzeug.exceptions import Unauthorized @@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource): 500: "Internal Server Error", } ) - @marshal_with(fields.parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: @@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource): user_input_form = features_dict.get("user_input_form", []) - return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + return fields.Parameters.model_validate(parameters).model_dump(mode="json") @web_ns.route("/meta") diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index b32e35d0ca..a55f2d0f5f 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -22,6 +22,7 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine +from core.workflow.nodes.agent.exc import AgentMaxIterationError from models.model import Message logger = logging.getLogger(__name__) @@ -165,6 +166,11 @@ class CotAgentRunner(BaseAgentRunner, ABC): scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" self._agent_scratchpad.append(scratchpad) + # Check if max iteration is reached and model still wants to call tools + if iteration_step == max_iteration_steps and scratchpad.action: + if scratchpad.action.action_name.lower() != "final answer": + raise AgentMaxIterationError(app_config.agent.max_iteration) + # get llm usage if "usage" in usage_dict: if usage_dict["usage"] is not None: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index dcc1326b33..68d14ad027 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -25,6 +25,7 @@ from 
core.model_runtime.entities.message_entities import ImagePromptMessageConte from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine +from core.workflow.nodes.agent.exc import AgentMaxIterationError from models.model import Message logger = logging.getLogger(__name__) @@ -222,6 +223,10 @@ class FunctionCallAgentRunner(BaseAgentRunner): final_answer += response + "\n" + # Check if max iteration is reached and model still wants to call tools + if iteration_step == max_iteration_steps and tool_calls: + raise AgentMaxIterationError(app_config.agent.max_iteration) + # call tools tool_responses = [] for tool_call_id, tool_call_name, tool_call_args in tool_calls: diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 12431976f0..a123fb0321 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -30,7 +30,6 @@ class SimpleModelProviderEntity(BaseModel): label: I18nObject icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None supported_model_types: list[ModelType] def __init__(self, provider_entity: ProviderEntity): @@ -44,7 +43,6 @@ class SimpleModelProviderEntity(BaseModel): label=provider_entity.label, icon_small=provider_entity.icon_small, icon_small_dark=provider_entity.icon_small_dark, - icon_large=provider_entity.icon_large, supported_model_types=provider_entity.supported_model_types, ) @@ -94,7 +92,6 @@ class DefaultModelProviderEntity(BaseModel): provider: str label: I18nObject icon_small: I18nObject | None = None - icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] = [] diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py deleted file mode 100644 index c5447c2b3f..0000000000 --- a/api/core/helper/tool_provider_cache.py +++ /dev/null @@ -1,58 +0,0 @@ -import json -import logging -from typing import Any, cast - -from core.tools.entities.api_entities import ToolProviderTypeApiLiteral -from extensions.ext_redis import redis_client, redis_fallback - -logger = logging.getLogger(__name__) - - -class ToolProviderListCache: - """Cache for tool provider lists""" - - CACHE_TTL = 300 # 5 minutes - - @staticmethod - def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str: - """Generate cache key for tool providers list""" - type_filter = typ or "all" - return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}" - - @staticmethod - @redis_fallback(default_return=None) - def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None: - """Get cached tool providers""" - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - cached_data = redis_client.get(cache_key) - if cached_data: - try: - return json.loads(cached_data.decode("utf-8")) - except (json.JSONDecodeError, UnicodeDecodeError): - logger.warning("Failed to decode cached tool providers data") - return None - return None - - @staticmethod - @redis_fallback() - def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]): - """Cache tool providers""" - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers)) - - @staticmethod - @redis_fallback() - def 
invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None): - """Invalidate cache for tool providers""" - if typ: - # Invalidate specific type cache - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - redis_client.delete(cache_key) - else: - # Invalidate all caches for this tenant - keys = ["builtin", "model", "api", "workflow", "mcp"] - pipeline = redis_client.pipeline() - for key in keys: - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key)) - pipeline.delete(cache_key) - pipeline.execute() diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index 648b209ef1..2d88751668 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -100,7 +100,6 @@ class SimpleProviderEntity(BaseModel): label: I18nObject icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] models: list[AIModelEntity] = [] @@ -123,7 +122,6 @@ class ProviderEntity(BaseModel): label: I18nObject description: I18nObject | None = None icon_small: I18nObject | None = None - icon_large: I18nObject | None = None icon_small_dark: I18nObject | None = None background: str | None = None help: ProviderHelpEntity | None = None @@ -157,7 +155,6 @@ class ProviderEntity(BaseModel): provider=self.provider, label=self.label, icon_small=self.icon_small, - icon_large=self.icon_large, supported_model_types=self.supported_model_types, models=self.models, ) diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index b8704ef4ed..12a202ce64 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -285,7 +285,7 @@ class ModelProviderFactory: """ Get provider icon :param provider: provider name - :param icon_type: icon type (icon_small or icon_large) + :param icon_type: icon type (icon_small or icon_small_dark) :param lang: language (zh_Hans or en_US) :return: provider icon """ @@ -309,13 +309,7 @@ class ModelProviderFactory: else: file_name = provider_schema.icon_small_dark.en_US else: - if not provider_schema.icon_large: - raise ValueError(f"Provider {provider} does not have large icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_large.zh_Hans - else: - file_name = provider_schema.icon_large.en_US + raise ValueError(f"Unsupported icon type: {icon_type}.") if not file_name: raise ValueError(f"Provider {provider} does not have icon.") diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6c818bdc8b..10d86d1762 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -331,7 +331,6 @@ class ProviderManager: provider=provider_schema.provider, label=provider_schema.label, icon_small=provider_schema.icon_small, - icon_large=provider_schema.icon_large, supported_model_types=provider_schema.supported_model_types, ), ) diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index 9cb009035b..e182c35b99 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -27,26 +27,44 @@ class CleanProcessor: pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" text = re.sub(pattern, "", text) - # 
Remove URL but keep Markdown image URLs - # First, temporarily replace Markdown image URLs with a placeholder - markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)" - placeholders: list[str] = [] + # Remove URL but keep Markdown image URLs and link URLs + # Replace the ENTIRE markdown link/image with a single placeholder to protect + # the link text (which might also be a URL) from being removed + markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)" + markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)" + placeholders: list[tuple[str, str, str]] = [] # (type, text, url) - def replace_with_placeholder(match, placeholders=placeholders): + def replace_markdown_with_placeholder(match, placeholders=placeholders): + link_type = "link" + link_text = match.group(1) + url = match.group(2) + placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__" + placeholders.append((link_type, link_text, url)) + return placeholder + + def replace_image_with_placeholder(match, placeholders=placeholders): + link_type = "image" url = match.group(1) - placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__" - placeholders.append(url) - return f"![image]({placeholder})" + placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__" + placeholders.append((link_type, "image", url)) + return placeholder - text = re.sub(markdown_image_pattern, replace_with_placeholder, text) + # Protect markdown links first + text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text) + # Then protect markdown images + text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text) # Now remove all remaining URLs - url_pattern = r"https?://[^\s)]+" + url_pattern = r"https?://\S+" text = re.sub(url_pattern, "", text) - # Finally, restore the Markdown image URLs - for i, url in enumerate(placeholders): - text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url) + # Restore the Markdown links and images + for i, (link_type, text_or_alt, url) in enumerate(placeholders): + placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__" + if link_type == "link": + text = text.replace(placeholder, f"[{text_or_alt}]({url})") + else: # image + text = text.replace(placeholder, f"![{text_or_alt}]({url})") return text def filter_string(self, text): diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 43912cd75d..8ec1ce6242 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,4 +1,5 @@ import concurrent.futures +import logging from concurrent.futures import ThreadPoolExecutor from typing import Any @@ -36,6 +37,8 @@ default_retrieval_model = { "score_threshold_enabled": False, } +logger = logging.getLogger(__name__) + class RetrievalService: # Cache precompiled regular expressions to avoid repeated compilation @@ -106,7 +109,12 @@ class RetrievalService: ) ) - concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED) + if futures: + for future in concurrent.futures.as_completed(futures, timeout=3600): + if exceptions: + for f in futures: + f.cancel() + break if exceptions: raise ValueError(";\n".join(exceptions)) @@ -210,6 +218,7 @@ class RetrievalService: ) all_documents.extend(documents) except Exception as e: + logger.error(e, exc_info=True) exceptions.append(str(e)) @classmethod @@ -303,6 +312,7 @@ class RetrievalService: else: all_documents.extend(documents) except Exception as e: + logger.error(e, exc_info=True) exceptions.append(str(e)) @classmethod @@ 
-351,6 +361,7 @@ class RetrievalService: else: all_documents.extend(documents) except Exception as e: + logger.error(e, exc_info=True) exceptions.append(str(e)) @staticmethod @@ -663,7 +674,14 @@ class RetrievalService: document_ids_filter=document_ids_filter, ) ) - concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED) + # Use as_completed for early error propagation - cancel remaining futures on first error + if futures: + for future in concurrent.futures.as_completed(futures, timeout=300): + if future.exception(): + # Cancel remaining futures to avoid unnecessary waiting + for f in futures: + f.cancel() + break if exceptions: raise ValueError(";\n".join(exceptions)) diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 013c287248..6d28ce25bc 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -112,7 +112,7 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": - extractor = PdfExtractor(file_path) + extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = ( UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key) @@ -148,7 +148,7 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": - extractor = PdfExtractor(file_path) + extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) elif file_extension in {".htm", ".html"}: diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 80530d99a6..6aabcac704 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,25 +1,57 @@ """Abstract interface for document loader implementations.""" import contextlib +import io +import logging +import uuid from collections.abc import Iterator +import pypdfium2 +import pypdfium2.raw as pdfium_c + +from configs import dify_config from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document +from extensions.ext_database import db from extensions.ext_storage import storage +from libs.datetime_utils import naive_utc_now +from models.enums import CreatorUserRole +from models.model import UploadFile + +logger = logging.getLogger(__name__) class PdfExtractor(BaseExtractor): - """Load pdf files. - + """ + PdfExtractor is used to extract text and images from PDF files. Args: - file_path: Path to the file to load. + file_path: Path to the PDF file. + tenant_id: Workspace ID. + user_id: ID of the user performing the extraction. + file_cache_key: Optional cache key for the extracted text. 
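+ Images found on each page are saved to storage, recorded as UploadFile rows, and appended to the page text as markdown image links.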
""" - def __init__(self, file_path: str, file_cache_key: str | None = None): - """Initialize with file path.""" + # Magic bytes for image format detection: (magic_bytes, extension, mime_type) + IMAGE_FORMATS = [ + (b"\xff\xd8\xff", "jpg", "image/jpeg"), + (b"\x89PNG\r\n\x1a\n", "png", "image/png"), + (b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"), + (b"GIF8", "gif", "image/gif"), + (b"BM", "bmp", "image/bmp"), + (b"II*\x00", "tiff", "image/tiff"), + (b"MM\x00*", "tiff", "image/tiff"), + (b"II+\x00", "tiff", "image/tiff"), + (b"MM\x00+", "tiff", "image/tiff"), + ] + MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS) + + def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None): + """Initialize PdfExtractor.""" self._file_path = file_path + self._tenant_id = tenant_id + self._user_id = user_id self._file_cache_key = file_cache_key def extract(self) -> list[Document]: @@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor): def parse(self, blob: Blob) -> Iterator[Document]: """Lazily parse the blob.""" - import pypdfium2 # type: ignore with blob.as_bytes_io() as file_path: pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) @@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor): text_page = page.get_textpage() content = text_page.get_text_range() text_page.close() + + image_content = self._extract_images(page) + if image_content: + content += "\n" + image_content + page.close() metadata = {"source": blob.source, "page": page_number} yield Document(page_content=content, metadata=metadata) finally: pdf_reader.close() + + def _extract_images(self, page) -> str: + """ + Extract images from a PDF page, save them to storage and database, + and return markdown image links. + + Args: + page: pypdfium2 page object. + + Returns: + Markdown string containing links to the extracted images. + """ + image_content = [] + upload_files = [] + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + + try: + image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,)) + for obj in image_objects: + try: + # Extract image bytes + img_byte_arr = io.BytesIO() + # Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly + # Fallback to png for other formats + obj.extract(img_byte_arr, fb_format="png") + img_bytes = img_byte_arr.getvalue() + + if not img_bytes: + continue + + header = img_bytes[: self.MAX_MAGIC_LEN] + image_ext = None + mime_type = None + for magic, ext, mime in self.IMAGE_FORMATS: + if header.startswith(magic): + image_ext = ext + mime_type = mime + break + + if not image_ext or not mime_type: + continue + + file_uuid = str(uuid.uuid4()) + file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." 
+ image_ext + + storage.save(file_key, img_bytes) + + # save file to db + upload_file = UploadFile( + tenant_id=self._tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=file_key, + name=file_key, + size=len(img_bytes), + extension=image_ext, + mime_type=mime_type, + created_by=self._user_id, + created_by_role=CreatorUserRole.ACCOUNT, + created_at=naive_utc_now(), + used=True, + used_by=self._user_id, + used_at=naive_utc_now(), + ) + upload_files.append(upload_file) + image_content.append(f"![image]({base_url}/files/{upload_file.id}/file-preview)") + except Exception as e: + logger.warning("Failed to extract image from PDF: %s", e) + continue + except Exception as e: + logger.warning("Failed to get objects from PDF page: %s", e) + if upload_files: + db.session.add_all(upload_files) + db.session.commit() + return "\n".join(image_content) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 2c3fc5ab75..4ec59940e3 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -516,6 +516,9 @@ class DatasetRetrieval: ].embedding_model_provider weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model with measure_time() as timer: + cancel_event = threading.Event() + thread_exceptions: list[Exception] = [] + if query: query_thread = threading.Thread( target=self._multiple_retrieve_thread, @@ -534,6 +537,8 @@ class DatasetRetrieval: "score_threshold": score_threshold, "query": query, "attachment_id": None, + "cancel_event": cancel_event, + "thread_exceptions": thread_exceptions, }, ) all_threads.append(query_thread) @@ -557,12 +562,25 @@ class DatasetRetrieval: "score_threshold": score_threshold, "query": None, "attachment_id": attachment_id, + "cancel_event": cancel_event, + "thread_exceptions": thread_exceptions, }, ) all_threads.append(attachment_thread) attachment_thread.start() - for thread in all_threads: - thread.join() + + # Poll threads with short timeout to detect errors quickly (fail-fast) + while any(t.is_alive() for t in all_threads): + for thread in all_threads: + thread.join(timeout=0.1) + if thread_exceptions: + cancel_event.set() + break + if thread_exceptions: + break + + if thread_exceptions: + raise thread_exceptions[0] self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id) if all_documents: @@ -1404,40 +1422,53 @@ class DatasetRetrieval: score_threshold: float, query: str | None, attachment_id: str | None, + cancel_event: threading.Event | None = None, + thread_exceptions: list[Exception] | None = None, ): - with flask_app.app_context(): - threads = [] - all_documents_item: list[Document] = [] - index_type = None - for dataset in available_datasets: - index_type = dataset.indexing_technique - document_ids_filter = None - if dataset.provider != "external": - if metadata_condition and not metadata_filter_document_ids: - continue - if metadata_filter_document_ids: - document_ids = metadata_filter_document_ids.get(dataset.id, []) - if document_ids: - document_ids_filter = document_ids - else: + try: + with flask_app.app_context(): + threads = [] + all_documents_item: list[Document] = [] + index_type = None + for dataset in available_datasets: + # Check for cancellation signal + if cancel_event and cancel_event.is_set(): + break + index_type = dataset.indexing_technique + document_ids_filter = None + if dataset.provider != "external": + if metadata_condition and not metadata_filter_document_ids: continue - 
retrieval_thread = threading.Thread( - target=self._retriever, - kwargs={ - "flask_app": flask_app, - "dataset_id": dataset.id, - "query": query, - "top_k": top_k, - "all_documents": all_documents_item, - "document_ids_filter": document_ids_filter, - "metadata_condition": metadata_condition, - "attachment_ids": [attachment_id] if attachment_id else None, - }, - ) - threads.append(retrieval_thread) - retrieval_thread.start() - for thread in threads: - thread.join() + if metadata_filter_document_ids: + document_ids = metadata_filter_document_ids.get(dataset.id, []) + if document_ids: + document_ids_filter = document_ids + else: + continue + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": flask_app, + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents_item, + "document_ids_filter": document_ids_filter, + "metadata_condition": metadata_condition, + "attachment_ids": [attachment_id] if attachment_id else None, + }, + ) + threads.append(retrieval_thread) + retrieval_thread.start() + + # Poll threads with short timeout to respond quickly to cancellation + while any(t.is_alive() for t in threads): + for thread in threads: + thread.join(timeout=0.1) + if cancel_event and cancel_event.is_set(): + break + if cancel_event and cancel_event.is_set(): + break if reranking_enable: # do rerank for searched documents @@ -1470,3 +1501,8 @@ class DatasetRetrieval: all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item if all_documents_item: all_documents.extend(all_documents_item) + except Exception as e: + if cancel_event: + cancel_event.set() + if thread_exceptions is not None: + thread_exceptions.append(e) diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 3486182192..584975de05 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser: @staticmethod def auto_parse_to_tool_bundle( content: str, extra_info: dict | None = None, warning: dict | None = None - ) -> tuple[list[ApiToolBundle], str]: + ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]: """ auto parse to tool bundle diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 0f9a91a111..4bfaa5e49b 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -4,6 +4,7 @@ import re def remove_leading_symbols(text: str) -> str: """ Remove leading punctuation or symbols from the given text. + Preserves markdown links like [text](url) at the start. Args: text (str): The input text to process. @@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str: Returns: str: The text with leading punctuation or symbols removed. 
""" + # Check if text starts with a markdown link - preserve it + markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)" + if re.match(markdown_link_pattern, text): + return text + # Match Unicode ranges for punctuation and symbols # FIXME this pattern is confused quick fix for #11868 maybe refactor it later pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+' diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 2bd973f831..5422f5250b 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -54,7 +54,6 @@ class WorkflowToolProviderController(ToolProviderController): raise ValueError("app not found") user = session.get(Account, db_provider.user_id) if db_provider.user_id else None - controller = WorkflowToolProviderController( entity=ToolProviderEntity( identity=ToolProviderIdentity( @@ -67,7 +66,7 @@ class WorkflowToolProviderController(ToolProviderController): credentials_schema=[], plugin_id=None, ), - provider_id="", + provider_id=db_provider.id, ) controller.tools = [ diff --git a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py index 78f8ecdcdf..b9c9243963 100644 --- a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py +++ b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py @@ -60,6 +60,7 @@ class SkipPropagator: if edge_states["has_taken"]: # Enqueue node self._state_manager.enqueue_node(downstream_node_id) + self._state_manager.start_execution(downstream_node_id) return # All edges are skipped, propagate skip to this node diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py index 944f5f0b20..ba2c83d8a6 100644 --- a/api/core/workflow/nodes/agent/exc.py +++ b/api/core/workflow/nodes/agent/exc.py @@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError): self.expected_type = expected_type self.actual_type = actual_type super().__init__(message) + + +class AgentMaxIterationError(AgentNodeError): + """Exception raised when the agent exceeds the maximum iteration limit.""" + + def __init__(self, max_iteration: int): + self.max_iteration = max_iteration + super().__init__( + f"Agent exceeded the maximum iteration limit of {max_iteration}. " + f"The agent was unable to complete the task within the allowed number of iterations." 
+ ) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 5cf4984709..764df5d178 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -12,9 +12,8 @@ from dify_app import DifyApp def _get_celery_ssl_options() -> dict[str, Any] | None: """Get SSL configuration for Celery broker/backend connections.""" - # Use REDIS_USE_SSL for consistency with the main Redis client # Only apply SSL if we're using Redis as broker/backend - if not dify_config.REDIS_USE_SSL: + if not dify_config.BROKER_USE_SSL: return None # Check if Celery is actually using Redis diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index ea5d982efc..cf092c6973 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -13,12 +13,20 @@ class TencentCosStorage(BaseStorage): super().__init__() self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME - config = CosConfig( - Region=dify_config.TENCENT_COS_REGION, - SecretId=dify_config.TENCENT_COS_SECRET_ID, - SecretKey=dify_config.TENCENT_COS_SECRET_KEY, - Scheme=dify_config.TENCENT_COS_SCHEME, - ) + if dify_config.TENCENT_COS_CUSTOM_DOMAIN: + config = CosConfig( + Domain=dify_config.TENCENT_COS_CUSTOM_DOMAIN, + SecretId=dify_config.TENCENT_COS_SECRET_ID, + SecretKey=dify_config.TENCENT_COS_SECRET_KEY, + Scheme=dify_config.TENCENT_COS_SCHEME, + ) + else: + config = CosConfig( + Region=dify_config.TENCENT_COS_REGION, + SecretId=dify_config.TENCENT_COS_SECRET_ID, + SecretKey=dify_config.TENCENT_COS_SECRET_KEY, + Scheme=dify_config.TENCENT_COS_SCHEME, + ) self.client = CosS3Client(config) def save(self, filename, data): diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 38835d5ac7..e69306dcb2 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from libs.helper import TimestampField @@ -12,7 +12,7 @@ annotation_fields = { } -def build_annotation_model(api_or_ns: Api | Namespace): +def build_annotation_model(api_or_ns: Namespace): """Build the annotation model for the API or Namespace.""" return api_or_ns.model("Annotation", annotation_fields) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index ecc267cf38..e4ca2e7a42 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -46,7 +46,7 @@ message_file_fields = { } -def build_message_file_model(api_or_ns: Api | Namespace): +def build_message_file_model(api_or_ns: Namespace): """Build the message file fields for the API or Namespace.""" return api_or_ns.model("MessageFile", message_file_fields) @@ -217,7 +217,7 @@ conversation_infinite_scroll_pagination_fields = { } -def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): +def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace): """Build the conversation infinite scroll pagination model for the API or Namespace.""" simple_conversation_model = build_simple_conversation_model(api_or_ns) @@ -226,11 +226,11 @@ def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespa return api_or_ns.model("ConversationInfiniteScrollPagination", 
copied_fields) -def build_conversation_delete_model(api_or_ns: Api | Namespace): +def build_conversation_delete_model(api_or_ns: Namespace): """Build the conversation delete model for the API or Namespace.""" return api_or_ns.model("ConversationDelete", conversation_delete_fields) -def build_simple_conversation_model(api_or_ns: Api | Namespace): +def build_simple_conversation_model(api_or_ns: Namespace): """Build the simple conversation model for the API or Namespace.""" return api_or_ns.model("SimpleConversation", simple_conversation_fields) diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 7d5e311591..c55014a368 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from libs.helper import TimestampField @@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = { } -def build_conversation_variable_model(api_or_ns: Api | Namespace): +def build_conversation_variable_model(api_or_ns: Namespace): """Build the conversation variable model for the API or Namespace.""" return api_or_ns.model("ConversationVariable", conversation_variable_fields) -def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): +def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace): """Build the conversation variable infinite scroll pagination model for the API or Namespace.""" # Build the nested variable model first conversation_variable_model = build_conversation_variable_model(api_or_ns) diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index ea43e3b5fd..5389b0213a 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields simple_end_user_fields = { "id": fields.String, @@ -8,5 +8,5 @@ simple_end_user_fields = { } -def build_simple_end_user_model(api_or_ns: Api | Namespace): +def build_simple_end_user_model(api_or_ns: Namespace): return api_or_ns.model("SimpleEndUser", simple_end_user_fields) diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index a707500445..70138404c7 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from libs.helper import TimestampField @@ -14,7 +14,7 @@ upload_config_fields = { } -def build_upload_config_model(api_or_ns: Api | Namespace): +def build_upload_config_model(api_or_ns: Namespace): """Build the upload config model for the API or Namespace. Args: @@ -39,7 +39,7 @@ file_fields = { } -def build_file_model(api_or_ns: Api | Namespace): +def build_file_model(api_or_ns: Namespace): """Build the file model for the API or Namespace. Args: @@ -57,7 +57,7 @@ remote_file_info_fields = { } -def build_remote_file_info_model(api_or_ns: Api | Namespace): +def build_remote_file_info_model(api_or_ns: Namespace): """Build the remote file info model for the API or Namespace. Args: @@ -81,7 +81,7 @@ file_fields_with_signed_url = { } -def build_file_with_signed_url_model(api_or_ns: Api | Namespace): +def build_file_with_signed_url_model(api_or_ns: Namespace): """Build the file with signed URL model for the API or Namespace. 
Args: diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 08e38a6931..25160927e6 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from libs.helper import AvatarUrlField, TimestampField @@ -9,7 +9,7 @@ simple_account_fields = { } -def build_simple_account_model(api_or_ns: Api | Namespace): +def build_simple_account_model(api_or_ns: Namespace): return api_or_ns.model("SimpleAccount", simple_account_fields) diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index a419da2e18..151ff6f826 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from fields.conversation_fields import message_file_fields from libs.helper import TimestampField @@ -10,7 +10,7 @@ feedback_fields = { } -def build_feedback_model(api_or_ns: Api | Namespace): +def build_feedback_model(api_or_ns: Namespace): """Build the feedback model for the API or Namespace.""" return api_or_ns.model("Feedback", feedback_fields) @@ -30,7 +30,7 @@ agent_thought_fields = { } -def build_agent_thought_model(api_or_ns: Api | Namespace): +def build_agent_thought_model(api_or_ns: Namespace): """Build the agent thought model for the API or Namespace.""" return api_or_ns.model("AgentThought", agent_thought_fields) diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index d5b7c86a04..e359a4408c 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields dataset_tag_fields = { "id": fields.String, @@ -8,5 +8,5 @@ dataset_tag_fields = { } -def build_dataset_tag_fields(api_or_ns: Api | Namespace): +def build_dataset_tag_fields(api_or_ns: Namespace): return api_or_ns.model("DataSetTag", dataset_tag_fields) diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 4cbdf6f0ca..0ebc03a98c 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields +from flask_restx import Namespace, fields from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields from fields.member_fields import build_simple_account_model, simple_account_fields @@ -17,7 +17,7 @@ workflow_app_log_partial_fields = { } -def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace): +def build_workflow_app_log_partial_model(api_or_ns: Namespace): """Build the workflow app log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_log_model(api_or_ns) simple_account_model = build_simple_account_model(api_or_ns) @@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = { } -def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace): +def build_workflow_app_log_pagination_model(api_or_ns: Namespace): """Build the workflow app log pagination model for the API or Namespace.""" # Build the nested partial model first workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns) diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 821ce62ecc..476025064f 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,4 +1,4 @@ -from flask_restx import Api, Namespace, fields 
+from flask_restx import Namespace, fields from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -19,7 +19,7 @@ workflow_run_for_log_fields = { } -def build_workflow_run_for_log_model(api_or_ns: Api | Namespace): +def build_workflow_run_for_log_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields) diff --git a/api/libs/archive_storage.py b/api/libs/archive_storage.py new file mode 100644 index 0000000000..f84d226447 --- /dev/null +++ b/api/libs/archive_storage.py @@ -0,0 +1,347 @@ +""" +Archive Storage Client for S3-compatible storage. + +This module provides a dedicated storage client for archiving or exporting logs +to S3-compatible object storage. +""" + +import base64 +import datetime +import gzip +import hashlib +import logging +from collections.abc import Generator +from typing import Any, cast + +import boto3 +import orjson +from botocore.client import Config +from botocore.exceptions import ClientError + +from configs import dify_config + +logger = logging.getLogger(__name__) + + +class ArchiveStorageError(Exception): + """Base exception for archive storage operations.""" + + pass + + +class ArchiveStorageNotConfiguredError(ArchiveStorageError): + """Raised when archive storage is not properly configured.""" + + pass + + +class ArchiveStorage: + """ + S3-compatible storage client for archiving or exporting. + + This client provides methods for storing and retrieving archived data in JSONL+gzip format. + """ + + def __init__(self, bucket: str): + if not dify_config.ARCHIVE_STORAGE_ENABLED: + raise ArchiveStorageNotConfiguredError("Archive storage is not enabled") + + if not bucket: + raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured") + if not all( + [ + dify_config.ARCHIVE_STORAGE_ENDPOINT, + bucket, + dify_config.ARCHIVE_STORAGE_ACCESS_KEY, + dify_config.ARCHIVE_STORAGE_SECRET_KEY, + ] + ): + raise ArchiveStorageNotConfiguredError( + "Archive storage configuration is incomplete. " + "Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, " + "ARCHIVE_STORAGE_SECRET_KEY, and a bucket name" + ) + + self.bucket = bucket + self.client = boto3.client( + "s3", + endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT, + aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY, + aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY, + region_name=dify_config.ARCHIVE_STORAGE_REGION, + config=Config(s3={"addressing_style": "path"}), + ) + + # Verify bucket accessibility + try: + self.client.head_bucket(Bucket=self.bucket) + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + if error_code == "404": + raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist") + elif error_code == "403": + raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'") + else: + raise ArchiveStorageError(f"Failed to access archive bucket: {e}") + + def put_object(self, key: str, data: bytes) -> str: + """ + Upload an object to the archive storage. 
+ + Args: + key: Object key (path) within the bucket + data: Binary data to upload + + Returns: + MD5 checksum of the uploaded data + + Raises: + ArchiveStorageError: If upload fails + """ + checksum = hashlib.md5(data).hexdigest() + try: + self.client.put_object( + Bucket=self.bucket, + Key=key, + Body=data, + ContentMD5=self._content_md5(data), + ) + logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum) + return checksum + except ClientError as e: + raise ArchiveStorageError(f"Failed to upload object '{key}': {e}") + + def get_object(self, key: str) -> bytes: + """ + Download an object from the archive storage. + + Args: + key: Object key (path) within the bucket + + Returns: + Binary data of the object + + Raises: + ArchiveStorageError: If download fails + FileNotFoundError: If object does not exist + """ + try: + response = self.client.get_object(Bucket=self.bucket, Key=key) + return response["Body"].read() + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + if error_code == "NoSuchKey": + raise FileNotFoundError(f"Archive object not found: {key}") + raise ArchiveStorageError(f"Failed to download object '{key}': {e}") + + def get_object_stream(self, key: str) -> Generator[bytes, None, None]: + """ + Stream an object from the archive storage. + + Args: + key: Object key (path) within the bucket + + Yields: + Chunks of binary data + + Raises: + ArchiveStorageError: If download fails + FileNotFoundError: If object does not exist + """ + try: + response = self.client.get_object(Bucket=self.bucket, Key=key) + yield from response["Body"].iter_chunks() + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code") + if error_code == "NoSuchKey": + raise FileNotFoundError(f"Archive object not found: {key}") + raise ArchiveStorageError(f"Failed to stream object '{key}': {e}") + + def object_exists(self, key: str) -> bool: + """ + Check if an object exists in the archive storage. + + Args: + key: Object key (path) within the bucket + + Returns: + True if object exists, False otherwise + """ + try: + self.client.head_object(Bucket=self.bucket, Key=key) + return True + except ClientError: + return False + + def delete_object(self, key: str) -> None: + """ + Delete an object from the archive storage. + + Args: + key: Object key (path) within the bucket + + Raises: + ArchiveStorageError: If deletion fails + """ + try: + self.client.delete_object(Bucket=self.bucket, Key=key) + logger.debug("Deleted object: %s", key) + except ClientError as e: + raise ArchiveStorageError(f"Failed to delete object '{key}': {e}") + + def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str: + """ + Generate a pre-signed URL for downloading an object. + + Args: + key: Object key (path) within the bucket + expires_in: URL validity duration in seconds (default: 1 hour) + + Returns: + Pre-signed URL string. + + Raises: + ArchiveStorageError: If generation fails + """ + try: + return self.client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": self.bucket, "Key": key}, + ExpiresIn=expires_in, + ) + except ClientError as e: + raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}") + + def list_objects(self, prefix: str) -> list[str]: + """ + List objects under a given prefix. 
+ + Args: + prefix: Object key prefix to filter by + + Returns: + List of object keys matching the prefix + """ + keys = [] + paginator = self.client.get_paginator("list_objects_v2") + + try: + for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix): + for obj in page.get("Contents", []): + keys.append(obj["Key"]) + except ClientError as e: + raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}") + + return keys + + @staticmethod + def _content_md5(data: bytes) -> str: + """Calculate base64-encoded MD5 for Content-MD5 header.""" + return base64.b64encode(hashlib.md5(data).digest()).decode() + + @staticmethod + def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes: + """ + Serialize records to gzipped JSONL format. + + Args: + records: List of dictionaries to serialize + + Returns: + Gzipped JSONL bytes + """ + lines = [] + for record in records: + # Convert datetime objects to ISO format strings + serialized = ArchiveStorage._serialize_record(record) + lines.append(orjson.dumps(serialized)) + + jsonl_content = b"\n".join(lines) + if jsonl_content: + jsonl_content += b"\n" + + return gzip.compress(jsonl_content) + + @staticmethod + def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]: + """ + Deserialize gzipped JSONL data to records. + + Args: + data: Gzipped JSONL bytes + + Returns: + List of dictionaries + """ + jsonl_content = gzip.decompress(data) + records = [] + + for line in jsonl_content.splitlines(): + if line: + records.append(orjson.loads(line)) + + return records + + @staticmethod + def _serialize_record(record: dict[str, Any]) -> dict[str, Any]: + """Serialize a single record, converting special types.""" + + def _serialize(item: Any) -> Any: + if isinstance(item, datetime.datetime): + return item.isoformat() + if isinstance(item, dict): + return {key: _serialize(value) for key, value in item.items()} + if isinstance(item, list): + return [_serialize(value) for value in item] + return item + + return cast(dict[str, Any], _serialize(record)) + + @staticmethod + def compute_checksum(data: bytes) -> str: + """Compute MD5 checksum of data.""" + return hashlib.md5(data).hexdigest() + + +# Singleton instance (lazy initialization) +_archive_storage: ArchiveStorage | None = None +_export_storage: ArchiveStorage | None = None + + +def get_archive_storage() -> ArchiveStorage: + """ + Get the archive storage singleton instance. + + Returns: + ArchiveStorage instance + + Raises: + ArchiveStorageNotConfiguredError: If archive storage is not configured + """ + global _archive_storage + if _archive_storage is None: + archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET + if not archive_bucket: + raise ArchiveStorageNotConfiguredError( + "Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET" + ) + _archive_storage = ArchiveStorage(bucket=archive_bucket) + return _archive_storage + + +def get_export_storage() -> ArchiveStorage: + """ + Get the export storage singleton instance. + + Returns: + ArchiveStorage instance + """ + global _export_storage + if _export_storage is None: + export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET + if not export_bucket: + raise ArchiveStorageNotConfiguredError( + "Archive export bucket is not configured. 
Required: ARCHIVE_STORAGE_EXPORT_BUCKET" + ) + _export_storage = ArchiveStorage(bucket=export_bucket) + return _export_storage diff --git a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py index 17ed067d81..657d28f896 100644 --- a/api/migrations/versions/00bacef91f18_rename_api_provider_description.py +++ b/api/migrations/versions/00bacef91f18_rename_api_provider_description.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '00bacef91f18' down_revision = '8ec536f3c800' @@ -23,31 +20,17 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description', sa.Text(), nullable=False)) - batch_op.drop_column('description_str') - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False)) - batch_op.drop_column('description_str') + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False)) + batch_op.drop_column('description_str') # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False)) - batch_op.drop_column('description') - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False)) - batch_op.drop_column('description') + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False)) + batch_op.drop_column('description') # ### end Alembic commands ### diff --git a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py index ed70bf5d08..912d9dbfa4 100644 --- a/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py +++ b/api/migrations/versions/114eed84c228_remove_tool_id_from_model_invoke.py @@ -7,14 +7,10 @@ Create Date: 2024-01-10 04:40:57.257824 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '114eed84c228' down_revision = 'c71211c8f604' @@ -32,13 +28,7 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: - batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False)) - else: - with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: - batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False)) + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py index 509bd5d0e8..0ca905129d 100644 --- a/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py +++ b/api/migrations/versions/161cadc1af8d_add_dataset_permission_tenant_id.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '161cadc1af8d' down_revision = '7e6a8693e07a' @@ -23,16 +20,9 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: - # Step 1: Add column without NOT NULL constraint - op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False)) - else: - with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: - # Step 1: Add column without NOT NULL constraint - op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False)) + with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: + # Step 1: Add column without NOT NULL constraint + op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py index 0767b725f6..be1b42f883 100644 --- a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py +++ b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py @@ -9,11 +9,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" -import sqlalchemy as sa -from sqlalchemy.dialects import postgresql - # revision identifiers, used by Alembic. revision = '6af6a521a53e' down_revision = 'd57ba9ebb251' @@ -23,58 +18,30 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('document_id', - existing_type=sa.UUID(), - nullable=True) - batch_op.alter_column('data_source_type', - existing_type=sa.TEXT(), - nullable=True) - batch_op.alter_column('segment_id', - existing_type=sa.UUID(), - nullable=True) - else: - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('document_id', - existing_type=models.types.StringUUID(), - nullable=True) - batch_op.alter_column('data_source_type', - existing_type=models.types.LongText(), - nullable=True) - batch_op.alter_column('segment_id', - existing_type=models.types.StringUUID(), - nullable=True) + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.alter_column('document_id', + existing_type=models.types.StringUUID(), + nullable=True) + batch_op.alter_column('data_source_type', + existing_type=models.types.LongText(), + nullable=True) + batch_op.alter_column('segment_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('segment_id', - existing_type=sa.UUID(), - nullable=False) - batch_op.alter_column('data_source_type', - existing_type=sa.TEXT(), - nullable=False) - batch_op.alter_column('document_id', - existing_type=sa.UUID(), - nullable=False) - else: - with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: - batch_op.alter_column('segment_id', - existing_type=models.types.StringUUID(), - nullable=False) - batch_op.alter_column('data_source_type', - existing_type=models.types.LongText(), - nullable=False) - batch_op.alter_column('document_id', - existing_type=models.types.StringUUID(), - nullable=False) + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.alter_column('segment_id', + existing_type=models.types.StringUUID(), + nullable=False) + batch_op.alter_column('data_source_type', + existing_type=models.types.LongText(), + nullable=False) + batch_op.alter_column('document_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py index a749c8bddf..5d12419bf7 100644 --- a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py +++ b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py @@ -8,7 +8,6 @@ Create Date: 2024-11-01 04:34:23.816198 from alembic import op import models as models import sqlalchemy as sa -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = 'd3f6769a94a3' diff --git a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py index 45842295ea..a49d6a52f6 100644 --- a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py +++ b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py @@ -28,85 +28,45 @@ def upgrade(): op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") - if _is_pg(conn): - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=sa.TEXT(), - nullable=False) + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + nullable=False) - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=sa.TEXT(), - nullable=False) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + nullable=False) - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=sa.TEXT(), - nullable=False) - else: - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - nullable=False) - - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - nullable=False) - - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - nullable=False) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + nullable=False) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.TEXT(), - type_=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + nullable=True) - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.TEXT(), - type_=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + nullable=True) - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=sa.TEXT(), - type_=sa.VARCHAR(length=255), - nullable=True) - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - nullable=True) - - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - nullable=True) - - with op.batch_alter_table('recommended_apps', schema=None) as batch_op: - batch_op.alter_column('custom_disclaimer', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + nullable=True) # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py index fdd8984029..8a36c9c4a5 100644 --- a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py +++ b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py @@ -49,57 +49,33 @@ def upgrade(): op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL") op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL") op.execute("UPDATE workflows SET features = '' WHERE features IS NULL") - if _is_pg(conn): - with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.alter_column('graph', - existing_type=sa.TEXT(), - nullable=False) - batch_op.alter_column('features', - existing_type=sa.TEXT(), - nullable=False) - batch_op.alter_column('updated_at', - existing_type=postgresql.TIMESTAMP(), - nullable=False) - else: - with op.batch_alter_table('workflows', schema=None) as batch_op: - batch_op.alter_column('graph', - existing_type=models.types.LongText(), - nullable=False) - batch_op.alter_column('features', - existing_type=models.types.LongText(), - nullable=False) - batch_op.alter_column('updated_at', - existing_type=sa.TIMESTAMP(), - nullable=False) + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('graph', + existing_type=models.types.LongText(), + nullable=False) + batch_op.alter_column('features', + existing_type=models.types.LongText(), + nullable=False) + 
batch_op.alter_column('updated_at',
+                              existing_type=sa.TIMESTAMP(),
+                              nullable=False)
 
     # ### end Alembic commands ###
 
 
 def downgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     conn = op.get_bind()
-
-    if _is_pg(conn):
-        with op.batch_alter_table('workflows', schema=None) as batch_op:
-            batch_op.alter_column('updated_at',
-                                  existing_type=postgresql.TIMESTAMP(),
-                                  nullable=True)
-            batch_op.alter_column('features',
-                                  existing_type=sa.TEXT(),
-                                  nullable=True)
-            batch_op.alter_column('graph',
-                                  existing_type=sa.TEXT(),
-                                  nullable=True)
-    else:
-        with op.batch_alter_table('workflows', schema=None) as batch_op:
-            batch_op.alter_column('updated_at',
-                                  existing_type=sa.TIMESTAMP(),
-                                  nullable=True)
-            batch_op.alter_column('features',
-                                  existing_type=models.types.LongText(),
-                                  nullable=True)
-            batch_op.alter_column('graph',
-                                  existing_type=models.types.LongText(),
-                                  nullable=True)
+    with op.batch_alter_table('workflows', schema=None) as batch_op:
+        batch_op.alter_column('updated_at',
+                              existing_type=sa.TIMESTAMP(),
+                              nullable=True)
+        batch_op.alter_column('features',
+                              existing_type=models.types.LongText(),
+                              nullable=True)
+        batch_op.alter_column('graph',
+                              existing_type=models.types.LongText(),
+                              nullable=True)
 
     if _is_pg(conn):
         with op.batch_alter_table('messages', schema=None) as batch_op:
diff --git a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py
index 16ca902726..1fc4a64df1 100644
--- a/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py
+++ b/api/migrations/versions/2025_08_13_1605-0e154742a5fa_add_provider_model_multi_credential.py
@@ -86,57 +86,30 @@ def upgrade():
 
 def migrate_existing_provider_models_data():
     """migrate provider_models table data to provider_model_credentials"""
-    conn = op.get_bind()
 
-    # Define table structure for data manipulation
-    if _is_pg(conn):
-        provider_models_table = table('provider_models',
-            column('id', models.types.StringUUID()),
-            column('tenant_id', models.types.StringUUID()),
-            column('provider_name', sa.String()),
-            column('model_name', sa.String()),
-            column('model_type', sa.String()),
-            column('encrypted_config', sa.Text()),
-            column('created_at', sa.DateTime()),
-            column('updated_at', sa.DateTime()),
-            column('credential_id', models.types.StringUUID()),
-        )
-    else:
-        provider_models_table = table('provider_models',
-            column('id', models.types.StringUUID()),
-            column('tenant_id', models.types.StringUUID()),
-            column('provider_name', sa.String()),
-            column('model_name', sa.String()),
-            column('model_type', sa.String()),
-            column('encrypted_config', models.types.LongText()),
-            column('created_at', sa.DateTime()),
-            column('updated_at', sa.DateTime()),
-            column('credential_id', models.types.StringUUID()),
-        )
+    # Define table structure for data manipulation
+    provider_models_table = table('provider_models',
+        column('id', models.types.StringUUID()),
+        column('tenant_id', models.types.StringUUID()),
+        column('provider_name', sa.String()),
+        column('model_name', sa.String()),
+        column('model_type', sa.String()),
+        column('encrypted_config', models.types.LongText()),
+        column('created_at', sa.DateTime()),
+        column('updated_at', sa.DateTime()),
+        column('credential_id', models.types.StringUUID()),
+    )
 
-    if _is_pg(conn):
-        provider_model_credentials_table = table('provider_model_credentials',
-            column('id', models.types.StringUUID()),
-            column('tenant_id',
models.types.StringUUID()), - column('provider_name', sa.String()), - column('model_name', sa.String()), - column('model_type', sa.String()), - column('credential_name', sa.String()), - column('encrypted_config', sa.Text()), - column('created_at', sa.DateTime()), - column('updated_at', sa.DateTime()) - ) - else: - provider_model_credentials_table = table('provider_model_credentials', - column('id', models.types.StringUUID()), - column('tenant_id', models.types.StringUUID()), - column('provider_name', sa.String()), - column('model_name', sa.String()), - column('model_type', sa.String()), - column('credential_name', sa.String()), - column('encrypted_config', models.types.LongText()), - column('created_at', sa.DateTime()), - column('updated_at', sa.DateTime()) - ) + provider_model_credentials_table = table('provider_model_credentials', + column('id', models.types.StringUUID()), + column('tenant_id', models.types.StringUUID()), + column('provider_name', sa.String()), + column('model_name', sa.String()), + column('model_type', sa.String()), + column('credential_name', sa.String()), + column('encrypted_config', models.types.LongText()), + column('created_at', sa.DateTime()), + column('updated_at', sa.DateTime()) + ) # Get database connection @@ -183,14 +156,8 @@ def migrate_existing_provider_models_data(): def downgrade(): # Re-add encrypted_config column to provider_models table - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('provider_models', schema=None) as batch_op: - batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('provider_models', schema=None) as batch_op: - batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True)) + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True)) if not context.is_offline_mode(): # Migrate data back from provider_model_credentials to provider_models diff --git a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py index 75b4d61173..79fe9d9bba 100644 --- a/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py +++ b/api/migrations/versions/2025_08_20_1747-8d289573e1da_add_oauth_provider_apps.py @@ -8,7 +8,6 @@ Create Date: 2025-08-20 17:47:17.015695 from alembic import op import models as models import sqlalchemy as sa -from libs.uuid_utils import uuidv7 def _is_pg(conn): diff --git a/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py b/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py index 4f472fe4b4..cf2b973d2d 100644 --- a/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py +++ b/api/migrations/versions/2025_09_08_1007-c20211f18133_add_headers_to_mcp_provider.py @@ -9,8 +9,6 @@ from alembic import op import models as models -def _is_pg(conn): - return conn.dialect.name == "postgresql" import sqlalchemy as sa @@ -23,12 +21,7 @@ depends_on = None def upgrade(): # Add encrypted_headers column to tool_mcp_providers table - conn = op.get_bind() - - if _is_pg(conn): - op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', sa.Text(), nullable=True)) - else: - op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True)) + 
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True)) def downgrade(): diff --git a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py index 8eac0dee10..bad516dcac 100644 --- a/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py +++ b/api/migrations/versions/2025_09_17_1515-68519ad5cd18_knowledge_pipeline_migrate.py @@ -44,6 +44,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'), sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx') ) + if _is_pg(conn): op.create_table('datasource_oauth_tenant_params', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -70,6 +71,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'), sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique') ) + if _is_pg(conn): op.create_table('datasource_providers', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -104,6 +106,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'), sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', 'name', name='datasource_provider_unique_name') ) + with op.batch_alter_table('datasource_providers', schema=None) as batch_op: batch_op.create_index('datasource_provider_auth_type_provider_idx', ['tenant_id', 'plugin_id', 'provider'], unique=False) @@ -133,6 +136,7 @@ def upgrade(): sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey') ) + with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op: batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False) @@ -174,6 +178,7 @@ def upgrade(): sa.Column('updated_by', models.types.StringUUID(), nullable=True), sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey') ) + if _is_pg(conn): op.create_table('pipeline_customized_templates', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -193,7 +198,6 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey') ) else: - # MySQL: Use compatible syntax op.create_table('pipeline_customized_templates', sa.Column('id', models.types.StringUUID(), nullable=False), sa.Column('tenant_id', models.types.StringUUID(), nullable=False), @@ -211,6 +215,7 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey') ) + with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op: batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False) @@ -236,6 +241,7 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='pipeline_recommended_plugin_pkey') ) + if _is_pg(conn): op.create_table('pipelines', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -266,6 +272,7 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), 
server_default=sa.func.current_timestamp(), nullable=False), sa.PrimaryKeyConstraint('id', name='pipeline_pkey') ) + if _is_pg(conn): op.create_table('workflow_draft_variable_files', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -292,6 +299,7 @@ def upgrade(): sa.Column('value_type', sa.String(20), nullable=False), sa.PrimaryKeyConstraint('id', name=op.f('workflow_draft_variable_files_pkey')) ) + if _is_pg(conn): op.create_table('workflow_node_execution_offload', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), @@ -316,6 +324,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name=op.f('workflow_node_execution_offload_pkey')), sa.UniqueConstraint('node_execution_id', 'type', name=op.f('workflow_node_execution_offload_node_execution_id_key')) ) + if _is_pg(conn): with op.batch_alter_table('datasets', schema=None) as batch_op: batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True)) @@ -342,6 +351,7 @@ def upgrade(): comment='Indicates whether the current value is the default for a conversation variable. Always `FALSE` for other types of variables.',) ) batch_op.create_index('workflow_draft_variable_file_id_idx', ['file_id'], unique=False) + if _is_pg(conn): with op.batch_alter_table('workflows', schema=None) as batch_op: batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False)) diff --git a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py index 0776ab0818..ec0cfbd11d 100644 --- a/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py +++ b/api/migrations/versions/2025_10_21_1430-ae662b25d9bc_remove_builtin_template_user.py @@ -9,8 +9,6 @@ from alembic import op import models as models -def _is_pg(conn): - return conn.dialect.name == "postgresql" import sqlalchemy as sa @@ -33,15 +31,9 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: - batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False)) - batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True)) - else: - with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: - batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False)) - batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True)) + + with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False)) + batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py index 627219cc4b..12905b3674 100644 --- a/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py +++ b/api/migrations/versions/2025_10_22_1611-03f8dcbc611e_add_workflowpause_model.py @@ -9,7 +9,6 @@ Create Date: 2025-10-22 16:11:31.805407 from alembic import op import models as models import sqlalchemy as sa -from libs.uuid_utils import uuidv7 def _is_pg(conn): return conn.dialect.name == "postgresql" diff --git a/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py b/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py index 9641a15c89..c27c1058d1 100644 --- a/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py +++ b/api/migrations/versions/2025_10_30_1518-669ffd70119c_introduce_trigger.py @@ -105,6 +105,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='trigger_oauth_tenant_client_pkey'), sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_trigger_oauth_tenant_client') ) + if _is_pg(conn): op.create_table('trigger_subscriptions', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), @@ -143,6 +144,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='trigger_provider_pkey'), sa.UniqueConstraint('tenant_id', 'provider_id', 'name', name='unique_trigger_provider') ) + with op.batch_alter_table('trigger_subscriptions', schema=None) as batch_op: batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint_id'], unique=True) batch_op.create_index('idx_trigger_providers_tenant_endpoint', ['tenant_id', 'endpoint_id'], unique=False) @@ -176,6 +178,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='workflow_plugin_trigger_pkey'), sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node_subscription') ) + with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op: batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id', 'event_name'], unique=False) @@ -207,6 +210,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='workflow_schedule_plan_pkey'), sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node') ) + with op.batch_alter_table('workflow_schedule_plans', schema=None) as batch_op: batch_op.create_index('workflow_schedule_plan_next_idx', ['next_run_at'], unique=False) @@ -264,6 +268,7 @@ def upgrade(): 
sa.Column('finished_at', sa.DateTime(), nullable=True), sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey') ) + with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op: batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False) batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False) @@ -299,6 +304,7 @@ def upgrade(): sa.UniqueConstraint('app_id', 'node_id', name='uniq_node'), sa.UniqueConstraint('webhook_id', name='uniq_webhook_id') ) + with op.batch_alter_table('workflow_webhook_triggers', schema=None) as batch_op: batch_op.create_index('workflow_webhook_trigger_tenant_idx', ['tenant_id'], unique=False) diff --git a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py index fae506906b..127ffd5599 100644 --- a/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py +++ b/api/migrations/versions/23db93619b9d_add_message_files_into_agent_thought.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '23db93619b9d' down_revision = '8ae9bc661daa' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_files', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True)) + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py index 2676ef0b94..31829d8e58 100644 --- a/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py +++ b/api/migrations/versions/246ba09cbbdb_add_app_anntation_setting.py @@ -62,14 +62,8 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True)) with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op: batch_op.drop_index('app_annotation_settings_app_idx') diff --git a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py index 3362a3a09f..07a8cd86b1 100644 --- a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py +++ b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py @@ -11,9 +11,6 @@ from alembic import op import models as models -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '2a3aebbbf4bb' down_revision = 'c031d46af369' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('apps', schema=None) as batch_op: - batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('apps', schema=None) as batch_op: - batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True)) + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py index 40bd727f66..211b2d8882 100644 --- a/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py +++ b/api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py @@ -7,14 +7,10 @@ Create Date: 2023-09-22 15:41:01.243183 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '2e9819ca5b28' down_revision = 'ab23c11305d4' @@ -24,35 +20,19 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True)) - batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) - batch_op.drop_column('dataset_id') - else: - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True)) - batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) - batch_op.drop_column('dataset_id') + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True)) + batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) + batch_op.drop_column('dataset_id') # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True)) - batch_op.drop_index('api_token_tenant_idx') - batch_op.drop_column('tenant_id') - else: - with op.batch_alter_table('api_tokens', schema=None) as batch_op: - batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True)) - batch_op.drop_index('api_token_tenant_idx') - batch_op.drop_column('tenant_id') + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True)) + batch_op.drop_index('api_token_tenant_idx') + batch_op.drop_column('tenant_id') # ### end Alembic commands ### diff --git a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py index 76056a9460..3491c85e2f 100644 --- a/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py +++ b/api/migrations/versions/42e85ed5564d_conversation_columns_set_nullable.py @@ -7,14 +7,10 @@ Create Date: 2024-03-07 08:30:29.133614 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '42e85ed5564d' down_revision = 'f9107f83abab' @@ -24,59 +20,31 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('app_model_config_id', - existing_type=postgresql.UUID(), - nullable=True) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=True) - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=True) - else: - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('app_model_config_id', - existing_type=models.types.StringUUID(), - nullable=True) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=True) - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=True) + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('app_model_config_id', + existing_type=models.types.StringUUID(), + nullable=True) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=True) + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('app_model_config_id', - existing_type=postgresql.UUID(), - nullable=False) - else: - with op.batch_alter_table('conversations', schema=None) as batch_op: - batch_op.alter_column('model_id', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('model_provider', - existing_type=sa.VARCHAR(length=255), - nullable=False) - batch_op.alter_column('app_model_config_id', - existing_type=models.types.StringUUID(), - nullable=False) + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('model_id', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('model_provider', + existing_type=sa.VARCHAR(length=255), + nullable=False) + batch_op.alter_column('app_model_config_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py index ef066587b7..8537a87233 100644 --- a/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py +++ b/api/migrations/versions/4829e54d2fee_change_message_chain_id_to_nullable.py @@ -6,14 +6,10 @@ Create Date: 2024-01-12 03:42:27.362415 """ from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '4829e54d2fee' down_revision = '114eed84c228' @@ -23,39 +19,21 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - # PostgreSQL: Keep original syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=postgresql.UUID(), - nullable=True) - else: - # MySQL: Use compatible syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=models.types.StringUUID(), - nullable=True) + + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.alter_column('message_chain_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - # PostgreSQL: Keep original syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=postgresql.UUID(), - nullable=False) - else: - # MySQL: Use compatible syntax - with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: - batch_op.alter_column('message_chain_id', - existing_type=models.types.StringUUID(), - nullable=False) + + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.alter_column('message_chain_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py index b080e7680b..22405e3cc8 100644 --- a/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py +++ b/api/migrations/versions/563cf8bf777b_enable_tool_file_without_conversation_id.py @@ -6,14 +6,10 @@ Create Date: 2024-03-14 04:54:56.679506 """ from alembic import op -from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '563cf8bf777b' down_revision = 'b5429b71023c' @@ -23,35 +19,19 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=True) - else: - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=True) + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=False) - else: - with op.batch_alter_table('tool_files', schema=None) as batch_op: - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=False) + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=False) # ### end Alembic commands ### diff --git a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py index 1ace8ea5a0..01d7d5ba21 100644 --- a/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py +++ b/api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py @@ -48,12 +48,9 @@ def upgrade(): with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False) - if _is_pg(conn): - with op.batch_alter_table('datasets', schema=None) as batch_op: - batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True)) - else: - with op.batch_alter_table('datasets', schema=None) as batch_op: - batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py index 457338ef42..0faa48f535 100644 --- a/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py +++ b/api/migrations/versions/714aafe25d39_add_anntation_history_match_response.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '714aafe25d39' down_revision = 'f2a6fc85e260' @@ -23,16 +20,9 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False)) - batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False)) - else: - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False)) - batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False)) + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False)) + batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py index 7bcd1a1be3..aa7b4a21e2 100644 --- a/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py +++ b/api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '77e83833755c' down_revision = '6dcb43972bdc' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py index 3c0aa082d5..34a17697d3 100644 --- a/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py +++ b/api/migrations/versions/7ce5a52e4eee_add_tool_providers.py @@ -27,7 +27,6 @@ def upgrade(): conn = op.get_bind() if _is_pg(conn): - # PostgreSQL: Keep original syntax op.create_table('tool_providers', sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('tenant_id', postgresql.UUID(), nullable=False), @@ -40,7 +39,6 @@ def upgrade(): sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') ) else: - # MySQL: Use compatible syntax op.create_table('tool_providers', sa.Column('id', models.types.StringUUID(), nullable=False), sa.Column('tenant_id', models.types.StringUUID(), nullable=False), @@ -52,12 +50,9 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') ) - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - 
batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True)) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py index beea90b384..884839c010 100644 --- a/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py +++ b/api/migrations/versions/88072f0caa04_add_custom_config_in_tenant.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '88072f0caa04' down_revision = '246ba09cbbdb' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tenants', schema=None) as batch_op: - batch_op.add_column(sa.Column('custom_config', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('tenants', schema=None) as batch_op: - batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True)) + with op.batch_alter_table('tenants', schema=None) as batch_op: + batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/89c7899ca936_.py b/api/migrations/versions/89c7899ca936_.py index 2420710e74..d26f1e82d6 100644 --- a/api/migrations/versions/89c7899ca936_.py +++ b/api/migrations/versions/89c7899ca936_.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '89c7899ca936' down_revision = '187385f442fc' @@ -23,39 +20,21 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=sa.VARCHAR(length=255), - type_=sa.Text(), - existing_nullable=True) - else: - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=sa.VARCHAR(length=255), - type_=models.types.LongText(), - existing_nullable=True) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('description', + existing_type=sa.VARCHAR(length=255), + type_=models.types.LongText(), + existing_nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=sa.Text(), - type_=sa.VARCHAR(length=255), - existing_nullable=True) - else: - with op.batch_alter_table('sites', schema=None) as batch_op: - batch_op.alter_column('description', - existing_type=models.types.LongText(), - type_=sa.VARCHAR(length=255), - existing_nullable=True) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('description', + existing_type=models.types.LongText(), + type_=sa.VARCHAR(length=255), + existing_nullable=True) # ### end Alembic commands ### diff --git a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py index 111e81240b..6022ea2c20 100644 --- a/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py +++ b/api/migrations/versions/8ec536f3c800_rename_api_provider_credentails.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = '8ec536f3c800' down_revision = 'ad472b61a054' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('credentials_str', sa.Text(), nullable=False)) - else: - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False)) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False)) # ### end Alembic commands ### diff --git a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py index 1c1c6cacbb..9d6d40114d 100644 --- a/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py +++ b/api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py @@ -57,12 +57,9 @@ def upgrade(): batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False) batch_op.create_index('message_file_message_idx', ['message_id'], unique=False) - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True)) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True)) if _is_pg(conn): with op.batch_alter_table('upload_files', schema=None) as batch_op: diff --git a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py index 5d29d354f3..0b3f92a12e 100644 --- a/api/migrations/versions/9f4e3427ea84_add_created_by_role.py +++ b/api/migrations/versions/9f4e3427ea84_add_created_by_role.py @@ -24,7 +24,6 @@ def upgrade(): conn = op.get_bind() if _is_pg(conn): - # PostgreSQL: Keep original syntax with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: 
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'::character varying"), nullable=False)) batch_op.drop_index('pinned_conversation_conversation_idx') @@ -35,7 +34,6 @@ def upgrade(): batch_op.drop_index('saved_message_message_idx') batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by_role', 'created_by'], unique=False) else: - # MySQL: Use compatible syntax with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'"), nullable=False)) batch_op.drop_index('pinned_conversation_conversation_idx') diff --git a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py index 616cb2f163..c8747a51f7 100644 --- a/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py +++ b/api/migrations/versions/a5b56fb053ef_app_config_add_speech_to_text.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'a5b56fb053ef' down_revision = 'd3d503a3471c' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('speech_to_text', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py index 900ff78036..f56aeb7e66 100644 --- a/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py +++ b/api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'a9836e3baeee' down_revision = '968fff4c0ab9' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/b24be59fbb04_.py b/api/migrations/versions/b24be59fbb04_.py index b0a6d10d8c..ae91eaf1bc 100644 --- a/api/migrations/versions/b24be59fbb04_.py +++ b/api/migrations/versions/b24be59fbb04_.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'b24be59fbb04' down_revision = 'de95f5c77138' @@ -23,14 +20,8 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('text_to_speech', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py index 772395c25b..c02c24c23f 100644 --- a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py +++ b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py @@ -11,9 +11,6 @@ from alembic import op import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'b3a09c049e8e' down_revision = '2e9819ca5b28' @@ -23,20 +20,11 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) - batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True)) - batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True)) - batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) - batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True)) - batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True)) - batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True)) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) + batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True)) + batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True)) + batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True)) # ### end Alembic commands ### diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py index 76be794ff4..fe51d1c78d 100644 --- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py +++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py @@ -7,7 +7,6 @@ Create Date: 2024-06-17 10:01:00.255189 """ import sqlalchemy as sa from alembic import op -from sqlalchemy.dialects import postgresql import models.types diff --git a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py index 9e02ec5d84..36e934f0fc 100644 --- a/api/migrations/versions/e1901f623fd0_add_annotation_reply.py +++ b/api/migrations/versions/e1901f623fd0_add_annotation_reply.py @@ -54,12 +54,9 @@ def upgrade(): batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False) batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False) - if _is_pg(conn): - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True)) - else: - with op.batch_alter_table('app_model_configs', schema=None) as batch_op: - batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True)) + + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True)) if _is_pg(conn): with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: @@ -68,54 +65,31 @@ def upgrade(): with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'"), nullable=False)) - if _is_pg(conn): - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.add_column(sa.Column('question', sa.Text(), 
nullable=True)) - batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False)) - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=True) - batch_op.alter_column('message_id', - existing_type=postgresql.UUID(), - nullable=True) - else: - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True)) - batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False)) - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=True) - batch_op.alter_column('message_id', - existing_type=models.types.StringUUID(), - nullable=True) + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True)) + batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False)) + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=True) + batch_op.alter_column('message_id', + existing_type=models.types.StringUUID(), + nullable=True) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - conn = op.get_bind() - if _is_pg(conn): - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.alter_column('message_id', - existing_type=postgresql.UUID(), - nullable=False) - batch_op.alter_column('conversation_id', - existing_type=postgresql.UUID(), - nullable=False) - batch_op.drop_column('hit_count') - batch_op.drop_column('question') - else: - with op.batch_alter_table('message_annotations', schema=None) as batch_op: - batch_op.alter_column('message_id', - existing_type=models.types.StringUUID(), - nullable=False) - batch_op.alter_column('conversation_id', - existing_type=models.types.StringUUID(), - nullable=False) - batch_op.drop_column('hit_count') - batch_op.drop_column('question') + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.alter_column('message_id', + existing_type=models.types.StringUUID(), + nullable=False) + batch_op.alter_column('conversation_id', + existing_type=models.types.StringUUID(), + nullable=False) + batch_op.drop_column('hit_count') + batch_op.drop_column('question') with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: batch_op.drop_column('type') diff --git a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py index 02098e91c1..ac1c14e50c 100644 --- a/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py +++ b/api/migrations/versions/f2a6fc85e260_add_anntation_history_message_id.py @@ -12,9 +12,6 @@ from sqlalchemy.dialects import postgresql import models.types -def _is_pg(conn): - return conn.dialect.name == "postgresql" - # revision identifiers, used by Alembic. revision = 'f2a6fc85e260' down_revision = '46976cc39132' @@ -24,16 +21,9 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - conn = op.get_bind() - - if _is_pg(conn): - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False)) - batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False) - else: - with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: - batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False)) - batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False) + with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: + batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False)) + batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False) # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index 420e6adc6c..f7a9c20026 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -8,7 +8,7 @@ from uuid import uuid4 import sqlalchemy as sa from flask_login import UserMixin from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, Session, mapped_column +from sqlalchemy.orm import Mapped, Session, mapped_column, validates from typing_extensions import deprecated from .base import TypeBase @@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase): role: TenantAccountRole | None = field(default=None, init=False) _current_tenant: "Tenant | None" = field(default=None, init=False) + @validates("status") + def _normalize_status(self, _key: str, value: str | AccountStatus) -> str: + if isinstance(value, AccountStatus): + return value.value + return value + @property def is_password_set(self): return self.password is not None diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py index db610df290..77d6b5a138 100644 --- a/api/schedule/queue_monitor_task.py +++ b/api/schedule/queue_monitor_task.py @@ -16,6 +16,11 @@ celery_redis = Redis( port=redis_config.get("port") or 6379, password=redis_config.get("password") or None, db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1, + ssl=bool(dify_config.BROKER_USE_SSL), + ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None, + ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None, + ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None, + ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None, ) logger = logging.getLogger(__name__) diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 3d7cb6cc8d..26ce8cad33 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,3 +1,4 @@ +import json import logging import os from collections.abc import Sequence @@ -31,6 +32,11 @@ class BillingService: compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60) + # Redis key prefix for tenant plan cache + _PLAN_CACHE_KEY_PREFIX = "tenant_plan:" + # Cache TTL: 10 minutes + _PLAN_CACHE_TTL = 600 + @classmethod def get_info(cls, tenant_id: str): params = {"tenant_id": tenant_id} @@ -272,14 +278,110 @@ class BillingService: data = resp.get("data", {}) for tenant_id, plan in data.items(): - subscription_plan = subscription_adapter.validate_python(plan) - 
results[tenant_id] = subscription_plan
+                    try:
+                        subscription_plan = subscription_adapter.validate_python(plan)
+                        results[tenant_id] = subscription_plan
+                    except Exception:
+                        logger.exception(
+                            "get_plan_bulk: failed to validate subscription plan for tenant(%s)", tenant_id
+                        )
+                        continue
             except Exception:
-                logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
+                logger.exception("get_plan_bulk: failed to fetch billing info batch for tenants: %s", chunk)
                 continue

         return results

+    @classmethod
+    def _make_plan_cache_key(cls, tenant_id: str) -> str:
+        return f"{cls._PLAN_CACHE_KEY_PREFIX}{tenant_id}"
+
+    @classmethod
+    def get_plan_bulk_with_cache(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
+        """
+        Bulk fetch billing subscription plans with a cache to reduce billing API load in batch job scenarios.
+
+        NOTE: if you need high data consistency, use get_plan_bulk instead.
+
+        Returns:
+            Mapping of tenant_id -> {plan: str, expiration_date: int}
+        """
+        tenant_plans: dict[str, SubscriptionPlan] = {}
+
+        if not tenant_ids:
+            return tenant_plans
+
+        subscription_adapter = TypeAdapter(SubscriptionPlan)
+
+        # Step 1: Batch fetch from Redis cache using mget
+        redis_keys = [cls._make_plan_cache_key(tenant_id) for tenant_id in tenant_ids]
+        try:
+            cached_values = redis_client.mget(redis_keys)
+
+            if len(cached_values) != len(tenant_ids):
+                raise Exception(
+                    "get_plan_bulk_with_cache: unexpected error: redis mget failed: cached values length mismatch"
+                )
+
+            # Map cached values back to tenant_ids
+            cache_misses: list[str] = []
+
+            for tenant_id, cached_value in zip(tenant_ids, cached_values):
+                if cached_value:
+                    try:
+                        # Redis returns bytes, decode to string and parse JSON
+                        json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
+                        plan_dict = json.loads(json_str)
+                        subscription_plan = subscription_adapter.validate_python(plan_dict)
+                        tenant_plans[tenant_id] = subscription_plan
+                    except Exception:
+                        logger.exception(
+                            "get_plan_bulk_with_cache: process tenant(%s) failed, add to cache misses", tenant_id
+                        )
+                        cache_misses.append(tenant_id)
+                else:
+                    cache_misses.append(tenant_id)
+
+            logger.info(
+                "get_plan_bulk_with_cache: cache hits=%s, cache misses=%s",
+                len(tenant_plans),
+                len(cache_misses),
+            )
+        except Exception:
+            logger.exception("get_plan_bulk_with_cache: redis mget failed, falling back to API")
+            cache_misses = list(tenant_ids)
+
+        # Step 2: Fetch missing plans from billing API
+        if cache_misses:
+            bulk_plans = BillingService.get_plan_bulk(cache_misses)
+
+            if bulk_plans:
+                plans_to_cache: dict[str, SubscriptionPlan] = {}
+
+                for tenant_id, subscription_plan in bulk_plans.items():
+                    tenant_plans[tenant_id] = subscription_plan
+                    plans_to_cache[tenant_id] = subscription_plan
+
+                # Step 3: Batch update Redis cache using pipeline
+                if plans_to_cache:
+                    try:
+                        pipe = redis_client.pipeline()
+                        for tenant_id, subscription_plan in plans_to_cache.items():
+                            redis_key = cls._make_plan_cache_key(tenant_id)
+                            # Serialize dict to JSON string
+                            json_str = json.dumps(subscription_plan)
+                            pipe.setex(redis_key, cls._PLAN_CACHE_TTL, json_str)
+                        pipe.execute()
+
+                        logger.info(
+                            "get_plan_bulk_with_cache: cached %s new tenant plans to Redis",
+                            len(plans_to_cache),
+                        )
+                    except Exception:
+                        logger.exception("get_plan_bulk_with_cache: redis pipeline failed")
+
+        return tenant_plans
+
     @classmethod
     def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
         resp = cls._send_request("GET",
"/subscription/cleanup/whitelist") diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index f405546909..a29d848ac5 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -70,7 +70,6 @@ class ProviderResponse(BaseModel): description: I18nObject | None = None icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None background: str | None = None help: ProviderHelpEntity | None = None supported_model_types: Sequence[ModelType] @@ -98,11 +97,6 @@ class ProviderResponse(BaseModel): en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans", ) - - if self.icon_large is not None: - self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" - ) return self @@ -116,7 +110,6 @@ class ProviderWithModelsResponse(BaseModel): label: I18nObject icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large: I18nObject | None = None status: CustomConfigurationStatus models: list[ProviderModelWithStatusEntity] @@ -134,11 +127,6 @@ class ProviderWithModelsResponse(BaseModel): self.icon_small_dark = I18nObject( en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans" ) - - if self.icon_large is not None: - self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" - ) return self @@ -163,11 +151,6 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): self.icon_small_dark = I18nObject( en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans" ) - - if self.icon_large is not None: - self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" - ) return self diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index eea382febe..edd1004b82 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -99,7 +99,6 @@ class ModelProviderService: description=provider_configuration.provider.description, icon_small=provider_configuration.provider.icon_small, icon_small_dark=provider_configuration.provider.icon_small_dark, - icon_large=provider_configuration.provider.icon_large, background=provider_configuration.provider.background, help=provider_configuration.provider.help, supported_model_types=provider_configuration.provider.supported_model_types, @@ -423,7 +422,6 @@ class ModelProviderService: label=first_model.provider.label, icon_small=first_model.provider.icon_small, icon_small_dark=first_model.provider.icon_small_dark, - icon_large=first_model.provider.icon_large, status=CustomConfigurationStatus.ACTIVE, models=[ ProviderModelWithStatusEntity( @@ -488,7 +486,6 @@ class ModelProviderService: provider=result.provider.provider, label=result.provider.label, icon_small=result.provider.icon_small, - icon_large=result.provider.icon_large, supported_model_types=result.provider.supported_model_types, ), ) @@ -522,7 +519,7 @@ class ModelProviderService: :param tenant_id: workspace id :param provider: provider name - :param icon_type: icon type (icon_small or icon_large) + :param icon_type: icon type (icon_small or icon_small_dark) :param lang: language (zh_Hans or en_US) :return: """ diff --git a/api/services/tools/api_tools_manage_service.py 
b/api/services/tools/api_tools_manage_service.py index b3b6e36346..c32157919b 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -7,7 +7,6 @@ from httpx import get from sqlalchemy import select from core.entities.provider_entities import ProviderConfig -from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_runtime import ToolRuntime from core.tools.custom_tool.provider import ApiToolProviderController @@ -86,7 +85,9 @@ class ApiToolManageService: raise ValueError(f"invalid schema: {str(e)}") @staticmethod - def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]: + def convert_schema_to_tool_bundles( + schema: str, extra_info: dict | None = None + ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]: """ convert schema to tool bundles @@ -104,7 +105,7 @@ class ApiToolManageService: provider_name: str, icon: dict, credentials: dict, - schema_type: str, + schema_type: ApiProviderSchemaType, schema: str, privacy_policy: str, custom_disclaimer: str, @@ -113,9 +114,6 @@ class ApiToolManageService: """ create api tool provider """ - if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f"invalid schema type {schema}") - provider_name = provider_name.strip() # check if the provider exists @@ -178,9 +176,6 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod @@ -245,18 +240,15 @@ class ApiToolManageService: original_provider: str, icon: dict, credentials: dict, - schema_type: str, + _schema_type: ApiProviderSchemaType, schema: str, - privacy_policy: str, + privacy_policy: str | None, custom_disclaimer: str, labels: list[str], ): """ update api tool provider """ - if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f"invalid schema type {schema}") - provider_name = provider_name.strip() # check if the provider exists @@ -281,7 +273,7 @@ class ApiToolManageService: provider.icon = json.dumps(icon) provider.schema = schema provider.description = extra_info.get("description", "") - provider.schema_type_str = ApiProviderSchemaType.OPENAPI + provider.schema_type_str = schema_type provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer @@ -322,9 +314,6 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod @@ -347,9 +336,6 @@ class ApiToolManageService: db.session.delete(provider) db.session.commit() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod @@ -366,7 +352,7 @@ class ApiToolManageService: tool_name: str, credentials: dict, parameters: dict, - schema_type: str, + schema_type: ApiProviderSchemaType, schema: str, ): """ diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 87951d53e6..6797a67dde 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ 
b/api/services/tools/builtin_tools_manage_service.py @@ -12,7 +12,6 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper.name_generator import generate_incremental_name from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache -from core.helper.tool_provider_cache import ToolProviderListCache from core.plugin.entities.plugin_daemon import CredentialType from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort @@ -205,9 +204,6 @@ class BuiltinToolManageService: db_provider.name = name session.commit() - - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) except Exception as e: session.rollback() raise ValueError(str(e)) @@ -290,8 +286,6 @@ class BuiltinToolManageService: session.rollback() raise ValueError(str(e)) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id, "builtin") return {"result": "success"} @staticmethod @@ -409,9 +403,6 @@ class BuiltinToolManageService: ) cache.delete() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @staticmethod @@ -434,8 +425,6 @@ class BuiltinToolManageService: target_provider.is_default = True session.commit() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) return {"result": "success"} @staticmethod diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 038c462f15..51e9120b8d 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -1,6 +1,5 @@ import logging -from core.helper.tool_provider_cache import ToolProviderListCache from core.tools.entities.api_entities import ToolProviderTypeApiLiteral from core.tools.tool_manager import ToolManager from services.tools.tools_transform_service import ToolTransformService @@ -16,14 +15,6 @@ class ToolCommonService: :return: the list of tool providers """ - # Try to get from cache first - cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ) - if cached_result is not None: - logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ) - return cached_result - - # Cache miss - fetch from database - logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ) providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ) # add icon @@ -32,7 +23,4 @@ class ToolCommonService: result = [provider.to_dict() for provider in providers] - # Cache the result - ToolProviderListCache.set_cached_providers(tenant_id, typ, result) - return result diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 714a651839..ab5d5480df 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -5,9 +5,8 @@ from datetime import datetime from typing import Any from sqlalchemy import or_, select +from sqlalchemy.orm import Session -from core.db.session_factory import session_factory -from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import 
ToolApiEntity, ToolProviderApiEntity @@ -86,17 +85,13 @@ class WorkflowToolManageService: except Exception as e: raise ValueError(str(e)) - with session_factory.create_session() as session, session.begin(): + with Session(db.engine, expire_on_commit=False) as session, session.begin(): session.add(workflow_tool_provider) if labels is not None: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @classmethod @@ -184,9 +179,6 @@ class WorkflowToolManageService: ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @classmethod @@ -249,9 +241,6 @@ class WorkflowToolManageService: db.session.commit() - # Invalidate tool providers cache - ToolProviderListCache.invalidate_cache(tenant_id) - return {"result": "success"} @classmethod diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index d59d5dc0fe..5012defdad 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -48,10 +48,6 @@ class MockModelClass(PluginModelClient): en_US="https://example.com/icon_small.png", zh_Hans="https://example.com/icon_small.png", ), - icon_large=I18nObject( - en_US="https://example.com/icon_large.png", - zh_Hans="https://example.com/icon_large.png", - ), supported_model_types=[ModelType.LLM], configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], models=[ diff --git a/api/tests/test_containers_integration_tests/services/test_billing_service.py b/api/tests/test_containers_integration_tests/services/test_billing_service.py new file mode 100644 index 0000000000..76708b36b1 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_billing_service.py @@ -0,0 +1,365 @@ +import json +from unittest.mock import patch + +import pytest + +from extensions.ext_redis import redis_client +from services.billing_service import BillingService + + +class TestBillingServiceGetPlanBulkWithCache: + """ + Comprehensive integration tests for get_plan_bulk_with_cache using testcontainers. 
+ + This test class covers all major scenarios: + - Cache hit/miss scenarios + - Redis operation failures and fallback behavior + - Invalid cache data handling + - TTL expiration handling + - Error recovery and logging + """ + + @pytest.fixture(autouse=True) + def setup_redis_cleanup(self, flask_app_with_containers): + """Clean up Redis cache before and after each test.""" + with flask_app_with_containers.app_context(): + # Clean up before test + yield + # Clean up after test + # Delete all test cache keys + pattern = f"{BillingService._PLAN_CACHE_KEY_PREFIX}*" + keys = redis_client.keys(pattern) + if keys: + redis_client.delete(*keys) + + def _create_test_plan_data(self, plan: str = "sandbox", expiration_date: int = 1735689600): + """Helper to create test SubscriptionPlan data.""" + return {"plan": plan, "expiration_date": expiration_date} + + def _set_cache(self, tenant_id: str, plan_data: dict, ttl: int = 600): + """Helper to set cache data in Redis.""" + cache_key = BillingService._make_plan_cache_key(tenant_id) + json_str = json.dumps(plan_data) + redis_client.setex(cache_key, ttl, json_str) + + def _get_cache(self, tenant_id: str): + """Helper to get cache data from Redis.""" + cache_key = BillingService._make_plan_cache_key(tenant_id) + value = redis_client.get(cache_key) + if value: + if isinstance(value, bytes): + return value.decode("utf-8") + return value + return None + + def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers): + """Test bulk plan retrieval when all tenants are in cache.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2", "tenant-3"] + expected_plans = { + "tenant-1": self._create_test_plan_data("sandbox", 1735689600), + "tenant-2": self._create_test_plan_data("professional", 1767225600), + "tenant-3": self._create_test_plan_data("team", 1798761600), + } + + # Pre-populate cache + for tenant_id, plan_data in expected_plans.items(): + self._set_cache(tenant_id, plan_data) + + # Act + with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 3 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-1"]["expiration_date"] == 1735689600 + assert result["tenant-2"]["plan"] == "professional" + assert result["tenant-2"]["expiration_date"] == 1767225600 + assert result["tenant-3"]["plan"] == "team" + assert result["tenant-3"]["expiration_date"] == 1798761600 + + # Verify API was not called + mock_get_plan_bulk.assert_not_called() + + def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers): + """Test bulk plan retrieval when all tenants are not in cache.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + expected_plans = { + "tenant-1": self._create_test_plan_data("sandbox", 1735689600), + "tenant-2": self._create_test_plan_data("professional", 1767225600), + } + + # Act + with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 2 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-2"]["plan"] == "professional" + + # Verify API was called with correct tenant_ids + mock_get_plan_bulk.assert_called_once_with(tenant_ids) + + # Verify data was written to cache + cached_1 = self._get_cache("tenant-1") + cached_2 = 
self._get_cache("tenant-2") + assert cached_1 is not None + assert cached_2 is not None + + # Verify cache content + cached_data_1 = json.loads(cached_1) + cached_data_2 = json.loads(cached_2) + assert cached_data_1 == expected_plans["tenant-1"] + assert cached_data_2 == expected_plans["tenant-2"] + + # Verify TTL is set + cache_key_1 = BillingService._make_plan_cache_key("tenant-1") + ttl_1 = redis_client.ttl(cache_key_1) + assert ttl_1 > 0 + assert ttl_1 <= 600 # Should be <= 600 seconds + + def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers): + """Test bulk plan retrieval when some tenants are in cache, some are not.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2", "tenant-3"] + # Pre-populate cache for tenant-1 and tenant-2 + self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600)) + self._set_cache("tenant-2", self._create_test_plan_data("professional", 1767225600)) + + # tenant-3 is not in cache + missing_plan = {"tenant-3": self._create_test_plan_data("team", 1798761600)} + + # Act + with patch.object(BillingService, "get_plan_bulk", return_value=missing_plan) as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 3 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-2"]["plan"] == "professional" + assert result["tenant-3"]["plan"] == "team" + + # Verify API was called only for missing tenant + mock_get_plan_bulk.assert_called_once_with(["tenant-3"]) + + # Verify tenant-3 data was written to cache + cached_3 = self._get_cache("tenant-3") + assert cached_3 is not None + cached_data_3 = json.loads(cached_3) + assert cached_data_3 == missing_plan["tenant-3"] + + def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers): + """Test fallback to API when Redis mget fails.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + expected_plans = { + "tenant-1": self._create_test_plan_data("sandbox", 1735689600), + "tenant-2": self._create_test_plan_data("professional", 1767225600), + } + + # Act + with ( + patch.object(redis_client, "mget", side_effect=Exception("Redis connection error")), + patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk, + ): + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 2 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-2"]["plan"] == "professional" + + # Verify API was called for all tenants (fallback) + mock_get_plan_bulk.assert_called_once_with(tenant_ids) + + # Verify data was written to cache after fallback + cached_1 = self._get_cache("tenant-1") + cached_2 = self._get_cache("tenant-2") + assert cached_1 is not None + assert cached_2 is not None + + def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers): + """Test fallback to API when cache contains invalid JSON.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2", "tenant-3"] + + # Set valid cache for tenant-1 + self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600)) + + # Set invalid JSON for tenant-2 + cache_key_2 = BillingService._make_plan_cache_key("tenant-2") + redis_client.setex(cache_key_2, 600, "invalid json {") + + # tenant-3 is not in cache + expected_plans = { + "tenant-2": 
self._create_test_plan_data("professional", 1767225600), + "tenant-3": self._create_test_plan_data("team", 1798761600), + } + + # Act + with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 3 + assert result["tenant-1"]["plan"] == "sandbox" # From cache + assert result["tenant-2"]["plan"] == "professional" # From API (fallback) + assert result["tenant-3"]["plan"] == "team" # From API + + # Verify API was called for tenant-2 and tenant-3 + mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"]) + + # Verify tenant-2's invalid JSON was replaced with correct data in cache + cached_2 = self._get_cache("tenant-2") + assert cached_2 is not None + cached_data_2 = json.loads(cached_2) + assert cached_data_2 == expected_plans["tenant-2"] + assert cached_data_2["plan"] == "professional" + assert cached_data_2["expiration_date"] == 1767225600 + + # Verify tenant-2 cache has correct TTL + cache_key_2_new = BillingService._make_plan_cache_key("tenant-2") + ttl_2 = redis_client.ttl(cache_key_2_new) + assert ttl_2 > 0 + assert ttl_2 <= 600 + + # Verify tenant-3 data was also written to cache + cached_3 = self._get_cache("tenant-3") + assert cached_3 is not None + cached_data_3 = json.loads(cached_3) + assert cached_data_3 == expected_plans["tenant-3"] + + def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers): + """Test fallback to API when cache data doesn't match SubscriptionPlan schema.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2", "tenant-3"] + + # Set valid cache for tenant-1 + self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600)) + + # Set invalid plan data for tenant-2 (missing expiration_date) + cache_key_2 = BillingService._make_plan_cache_key("tenant-2") + invalid_data = json.dumps({"plan": "professional"}) # Missing expiration_date + redis_client.setex(cache_key_2, 600, invalid_data) + + # tenant-3 is not in cache + expected_plans = { + "tenant-2": self._create_test_plan_data("professional", 1767225600), + "tenant-3": self._create_test_plan_data("team", 1798761600), + } + + # Act + with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 3 + assert result["tenant-1"]["plan"] == "sandbox" # From cache + assert result["tenant-2"]["plan"] == "professional" # From API (fallback) + assert result["tenant-3"]["plan"] == "team" # From API + + # Verify API was called for tenant-2 and tenant-3 + mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"]) + + def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers): + """Test that pipeline failure doesn't affect return value.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + expected_plans = { + "tenant-1": self._create_test_plan_data("sandbox", 1735689600), + "tenant-2": self._create_test_plan_data("professional", 1767225600), + } + + # Act + with ( + patch.object(BillingService, "get_plan_bulk", return_value=expected_plans), + patch.object(redis_client, "pipeline") as mock_pipeline, + ): + # Create a mock pipeline that fails on execute + mock_pipe = mock_pipeline.return_value + mock_pipe.execute.side_effect = Exception("Pipeline 
execution failed") + + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert - Function should still return correct result despite pipeline failure + assert len(result) == 2 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-2"]["plan"] == "professional" + + # Verify pipeline was attempted + mock_pipeline.assert_called_once() + + def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers): + """Test with empty tenant_ids list.""" + with flask_app_with_containers.app_context(): + # Act + with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache([]) + + # Assert + assert result == {} + assert len(result) == 0 + + # Verify no API calls + mock_get_plan_bulk.assert_not_called() + + # Verify no Redis operations (mget with empty list would return empty list) + # But we should check that mget was not called at all + # Since we can't easily verify this without more mocking, we just verify the result + + def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers): + """Test that expired cache keys are treated as cache misses.""" + with flask_app_with_containers.app_context(): + # Arrange + tenant_ids = ["tenant-1", "tenant-2"] + + # Set cache for tenant-1 with very short TTL (1 second) to simulate expiration + self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600), ttl=1) + + # Wait for TTL to expire (key will be deleted by Redis) + import time + + time.sleep(2) + + # Verify cache is expired (key doesn't exist) + cache_key_1 = BillingService._make_plan_cache_key("tenant-1") + exists = redis_client.exists(cache_key_1) + assert exists == 0 # Key doesn't exist (expired) + + # tenant-2 is not in cache + expected_plans = { + "tenant-1": self._create_test_plan_data("sandbox", 1735689600), + "tenant-2": self._create_test_plan_data("professional", 1767225600), + } + + # Act + with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk: + result = BillingService.get_plan_bulk_with_cache(tenant_ids) + + # Assert + assert len(result) == 2 + assert result["tenant-1"]["plan"] == "sandbox" + assert result["tenant-2"]["plan"] == "professional" + + # Verify API was called for both tenants (tenant-1 expired, tenant-2 missing) + mock_get_plan_bulk.assert_called_once_with(tenant_ids) + + # Verify both were written to cache with correct TTL + cache_key_1_new = BillingService._make_plan_cache_key("tenant-1") + cache_key_2 = BillingService._make_plan_cache_key("tenant-2") + ttl_1_new = redis_client.ttl(cache_key_1_new) + ttl_2 = redis_client.ttl(cache_key_2) + assert ttl_1_new > 0 + assert ttl_1_new <= 600 + assert ttl_2 > 0 + assert ttl_2 <= 600 diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 612210ef86..d57ab7428b 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -228,7 +228,6 @@ class TestModelProviderService: mock_provider_entity.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI ęä¾›å•†"} mock_provider_entity.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} mock_provider_entity.icon_small_dark = None - mock_provider_entity.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} 
mock_provider_entity.background = "#FF6B6B" mock_provider_entity.help = None mock_provider_entity.supported_model_types = [ModelType.LLM, ModelType.TEXT_EMBEDDING] @@ -302,7 +301,6 @@ class TestModelProviderService: mock_provider_entity_llm.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI ęä¾›å•†"} mock_provider_entity_llm.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} mock_provider_entity_llm.icon_small_dark = None - mock_provider_entity_llm.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} mock_provider_entity_llm.background = "#FF6B6B" mock_provider_entity_llm.help = None mock_provider_entity_llm.supported_model_types = [ModelType.LLM] @@ -316,7 +314,6 @@ class TestModelProviderService: mock_provider_entity_embedding.description = {"en_US": "Cohere provider", "zh_Hans": "Cohere ęä¾›å•†"} mock_provider_entity_embedding.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} mock_provider_entity_embedding.icon_small_dark = None - mock_provider_entity_embedding.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} mock_provider_entity_embedding.background = "#4ECDC4" mock_provider_entity_embedding.help = None mock_provider_entity_embedding.supported_model_types = [ModelType.TEXT_EMBEDDING] @@ -419,7 +416,6 @@ class TestModelProviderService: provider="openai", label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), - icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"), supported_model_types=[ModelType.LLM], configurate_methods=[], models=[], @@ -431,7 +427,6 @@ class TestModelProviderService: provider="openai", label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), - icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"), supported_model_types=[ModelType.LLM], configurate_methods=[], models=[], @@ -655,7 +650,6 @@ class TestModelProviderService: provider="openai", label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), - icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"), supported_model_types=[ModelType.LLM], ), ) @@ -1027,7 +1021,6 @@ class TestModelProviderService: label={"en_US": "OpenAI", "zh_Hans": "OpenAI"}, icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}, icon_small_dark=None, - icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}, ), model="gpt-3.5-turbo", model_type=ModelType.LLM, @@ -1045,7 +1038,6 @@ class TestModelProviderService: label={"en_US": "OpenAI", "zh_Hans": "OpenAI"}, icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}, icon_small_dark=None, - icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}, ), model="gpt-4", model_type=ModelType.LLM, diff --git a/api/tests/unit_tests/controllers/common/test_fields.py b/api/tests/unit_tests/controllers/common/test_fields.py new file mode 100644 index 0000000000..d4dc13127d --- /dev/null +++ b/api/tests/unit_tests/controllers/common/test_fields.py @@ -0,0 +1,69 @@ +import builtins +from types import SimpleNamespace +from unittest.mock import patch + +from flask.views import MethodView as FlaskMethodView + +_NEEDS_METHOD_VIEW_CLEANUP = False +if not hasattr(builtins, "MethodView"): + builtins.MethodView = FlaskMethodView + _NEEDS_METHOD_VIEW_CLEANUP = True +from controllers.common.fields import Parameters, Site +from 
core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict +from models.model import IconType + + +def test_parameters_model_round_trip(): + parameters = get_parameters_from_feature_dict(features_dict={}, user_input_form=[]) + + model = Parameters.model_validate(parameters) + + assert model.model_dump(mode="json") == parameters + + +def test_site_icon_url_uses_signed_url_for_image_icon(): + site = SimpleNamespace( + title="Example", + chat_color_theme=None, + chat_color_theme_inverted=False, + icon_type=IconType.IMAGE, + icon="file-id", + icon_background=None, + description=None, + copyright=None, + privacy_policy=None, + custom_disclaimer=None, + default_language="en-US", + show_workflow_steps=True, + use_icon_as_answer_icon=False, + ) + + with patch("controllers.common.fields.file_helpers.get_signed_file_url", return_value="signed") as mock_helper: + model = Site.model_validate(site) + + assert model.icon_url == "signed" + mock_helper.assert_called_once_with("file-id") + + +def test_site_icon_url_is_none_for_non_image_icon(): + site = SimpleNamespace( + title="Example", + chat_color_theme=None, + chat_color_theme_inverted=False, + icon_type=IconType.EMOJI, + icon="file-id", + icon_background=None, + description=None, + copyright=None, + privacy_policy=None, + custom_disclaimer=None, + default_language="en-US", + show_workflow_steps=True, + use_icon_as_answer_icon=False, + ) + + with patch("controllers.common.fields.file_helpers.get_signed_file_url") as mock_helper: + model = Site.model_validate(site) + + assert model.icon_url is None + mock_helper.assert_not_called() diff --git a/api/tests/unit_tests/controllers/console/app/test_xss_prevention.py b/api/tests/unit_tests/controllers/console/app/test_xss_prevention.py new file mode 100644 index 0000000000..313818547b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_xss_prevention.py @@ -0,0 +1,254 @@ +""" +Unit tests for XSS prevention in App payloads. + +This test module validates that HTML tags, JavaScript, and other potentially +dangerous content are rejected in App names and descriptions. +""" + +import pytest + +from controllers.console.app.app import CopyAppPayload, CreateAppPayload, UpdateAppPayload + + +class TestXSSPreventionUnit: + """Unit tests for XSS prevention in App payloads.""" + + def test_create_app_valid_names(self): + """Test CreateAppPayload with valid app names.""" + # Normal app names should be valid + valid_names = [ + "My App", + "Test App 123", + "App with - dash", + "App with _ underscore", + "App with + plus", + "App with () parentheses", + "App with [] brackets", + "App with {} braces", + "App with ! 
exclamation", + "App with @ at", + "App with # hash", + "App with $ dollar", + "App with % percent", + "App with ^ caret", + "App with & ampersand", + "App with * asterisk", + "Unicode: 测试应用", + "Emoji: šŸ¤–", + "Mixed: Test 测试 123", + ] + + for name in valid_names: + payload = CreateAppPayload( + name=name, + mode="chat", + ) + assert payload.name == name + + def test_create_app_xss_script_tags(self): + """Test CreateAppPayload rejects script tags.""" + xss_payloads = [ + "", + "", + "", + "", + "", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_iframe_tags(self): + """Test CreateAppPayload rejects iframe tags.""" + xss_payloads = [ + "", + "", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_javascript_protocol(self): + """Test CreateAppPayload rejects javascript: protocol.""" + xss_payloads = [ + "javascript:alert(1)", + "JAVASCRIPT:alert(1)", + "JavaScript:alert(document.cookie)", + "javascript:void(0)", + "javascript://comment%0Aalert(1)", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_svg_onload(self): + """Test CreateAppPayload rejects SVG with onload.""" + xss_payloads = [ + "", + "", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_event_handlers(self): + """Test CreateAppPayload rejects HTML event handlers.""" + xss_payloads = [ + "
", + "", + "", + "", + "", + "
", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_object_embed(self): + """Test CreateAppPayload rejects object and embed tags.""" + xss_payloads = [ + "", + "", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_link_javascript(self): + """Test CreateAppPayload rejects link tags with javascript.""" + xss_payloads = [ + "", + "", + ] + + for name in xss_payloads: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_xss_in_description(self): + """Test CreateAppPayload rejects XSS in description.""" + xss_descriptions = [ + "", + "javascript:alert(1)", + "", + ] + + for description in xss_descriptions: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload( + name="Valid Name", + mode="chat", + description=description, + ) + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_create_app_valid_descriptions(self): + """Test CreateAppPayload with valid descriptions.""" + valid_descriptions = [ + "A simple description", + "Description with < and > symbols", + "Description with & ampersand", + "Description with 'quotes' and \"double quotes\"", + "Description with / slashes", + "Description with \\ backslashes", + "Description with ; semicolons", + "Unicode: čæ™ę˜Æäø€äøŖęčæ°", + "Emoji: šŸŽ‰šŸš€", + ] + + for description in valid_descriptions: + payload = CreateAppPayload( + name="Valid App Name", + mode="chat", + description=description, + ) + assert payload.description == description + + def test_create_app_none_description(self): + """Test CreateAppPayload with None description.""" + payload = CreateAppPayload( + name="Valid App Name", + mode="chat", + description=None, + ) + assert payload.description is None + + def test_update_app_xss_prevention(self): + """Test UpdateAppPayload also prevents XSS.""" + xss_names = [ + "", + "javascript:alert(1)", + "", + ] + + for name in xss_names: + with pytest.raises(ValueError) as exc_info: + UpdateAppPayload(name=name) + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_update_app_valid_names(self): + """Test UpdateAppPayload with valid names.""" + payload = UpdateAppPayload(name="Valid Updated Name") + assert payload.name == "Valid Updated Name" + + def test_copy_app_xss_prevention(self): + """Test CopyAppPayload also prevents XSS.""" + xss_names = [ + "", + "javascript:alert(1)", + "", + ] + + for name in xss_names: + with pytest.raises(ValueError) as exc_info: + CopyAppPayload(name=name) + assert "invalid characters or patterns" in str(exc_info.value).lower() + + def test_copy_app_valid_names(self): + """Test CopyAppPayload with valid names.""" + payload = CopyAppPayload(name="Valid Copy Name") + assert payload.name == "Valid Copy Name" + + def test_copy_app_none_name(self): + """Test CopyAppPayload with None name (should be allowed).""" + payload = CopyAppPayload(name=None) + assert payload.name is None + + def test_edge_case_angle_brackets_content(self): + """Test that angle brackets with actual content are rejected.""" + # Angle brackets without valid HTML-like patterns should be 
checked + # The regex pattern <.*?on\w+\s*= should catch event handlers + # Verify a few additional event-handler variants below + + # Valid: angle brackets used as symbols (not matched by our patterns) + # Our patterns specifically look for dangerous constructs + + # Invalid: actual HTML tags with event handlers + invalid_names = [ + "
", + "", + ] + + for name in invalid_names: + with pytest.raises(ValueError) as exc_info: + CreateAppPayload(name=name, mode="chat") + assert "invalid characters or patterns" in str(exc_info.value).lower() diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 399caf8c4d..3ddfcdb832 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -171,7 +171,7 @@ class TestOAuthCallback: ): mock_config.CONSOLE_WEB_URL = "http://localhost:3000" mock_get_providers.return_value = {"github": oauth_setup["provider"]} - mock_generate_account.return_value = oauth_setup["account"] + mock_generate_account.return_value = (oauth_setup["account"], True) mock_account_service.login.return_value = oauth_setup["token_pair"] with app.test_request_context("/auth/oauth/github/callback?code=test_code"): @@ -179,7 +179,7 @@ class TestOAuthCallback: oauth_setup["provider"].get_access_token.assert_called_once_with("test_code") oauth_setup["provider"].get_user_info.assert_called_once_with("access_token") - mock_redirect.assert_called_once_with("http://localhost:3000") + mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=true") @pytest.mark.parametrize( ("exception", "expected_error"), @@ -223,7 +223,7 @@ class TestOAuthCallback: # This documents actual behavior. See test_defensive_check_for_closed_account_status for details ( AccountStatus.CLOSED.value, - "http://localhost:3000", + "http://localhost:3000?oauth_new_user=false", ), ], ) @@ -260,7 +260,7 @@ class TestOAuthCallback: account = MagicMock() account.status = account_status account.id = "123" - mock_generate_account.return_value = account + mock_generate_account.return_value = (account, False) # Mock login for CLOSED status mock_token_pair = MagicMock() @@ -296,7 +296,7 @@ class TestOAuthCallback: mock_account = MagicMock() mock_account.status = AccountStatus.PENDING - mock_generate_account.return_value = mock_account + mock_generate_account.return_value = (mock_account, False) mock_token_pair = MagicMock() mock_token_pair.access_token = "jwt_access_token" @@ -360,7 +360,7 @@ class TestOAuthCallback: closed_account.status = AccountStatus.CLOSED closed_account.id = "123" closed_account.name = "Closed Account" - mock_generate_account.return_value = closed_account + mock_generate_account.return_value = (closed_account, False) # Mock successful login (current behavior) mock_token_pair = MagicMock() @@ -374,7 +374,7 @@ class TestOAuthCallback: resource.get("github") # Verify current behavior: login succeeds (this is NOT ideal) - mock_redirect.assert_called_once_with("http://localhost:3000") + mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=false") mock_account_service.login.assert_called_once() # Document expected behavior in comments: @@ -458,8 +458,9 @@ class TestAccountGeneration: with pytest.raises(AccountRegisterError): _generate_account("github", user_info) else: - result = _generate_account("github", user_info) + result, oauth_new_user = _generate_account("github", user_info) assert result == mock_account + assert oauth_new_user == should_create if should_create: mock_register_service.register.assert_called_once_with( @@ -490,9 +491,10 @@ class TestAccountGeneration: mock_tenant_service.create_tenant.return_value = mock_new_tenant with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}): - result = 
_generate_account("github", user_info) + result, oauth_new_user = _generate_account("github", user_info) assert result == mock_account + assert oauth_new_user is False mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace") mock_tenant_service.create_tenant_member.assert_called_once_with( mock_new_tenant, mock_account, role="owner" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py index 2b03813ef4..c608f731c5 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py @@ -41,13 +41,10 @@ def client(): @patch( "controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1") ) -@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None) @patch("controllers.console.workspace.tool_providers.Session") @patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url") @pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant") -def test_create_mcp_provider_populates_tools( - mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client -): +def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client): # Arrange: reconnect returns tools immediately mock_reconnect.return_value = ReconnectResult( authed=True, diff --git a/api/tests/unit_tests/core/helper/test_tool_provider_cache.py b/api/tests/unit_tests/core/helper/test_tool_provider_cache.py deleted file mode 100644 index d237c68f35..0000000000 --- a/api/tests/unit_tests/core/helper/test_tool_provider_cache.py +++ /dev/null @@ -1,126 +0,0 @@ -import json -from unittest.mock import patch - -import pytest -from redis.exceptions import RedisError - -from core.helper.tool_provider_cache import ToolProviderListCache -from core.tools.entities.api_entities import ToolProviderTypeApiLiteral - - -@pytest.fixture -def mock_redis_client(): - """Fixture: Mock Redis client""" - with patch("core.helper.tool_provider_cache.redis_client") as mock: - yield mock - - -class TestToolProviderListCache: - """Test class for ToolProviderListCache""" - - def test_generate_cache_key(self): - """Test cache key generation logic""" - # Scenario 1: Specify typ (valid literal value) - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "builtin" - expected_key = f"tool_providers:tenant_id:{tenant_id}:type:{typ}" - assert ToolProviderListCache._generate_cache_key(tenant_id, typ) == expected_key - - # Scenario 2: typ is None (defaults to "all") - expected_key_all = f"tool_providers:tenant_id:{tenant_id}:type:all" - assert ToolProviderListCache._generate_cache_key(tenant_id) == expected_key_all - - def test_get_cached_providers_hit(self, mock_redis_client): - """Test get cached providers - cache hit and successful decoding""" - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "api" - mock_providers = [{"id": "tool", "name": "test_provider"}] - mock_redis_client.get.return_value = json.dumps(mock_providers).encode("utf-8") - - result = ToolProviderListCache.get_cached_providers(tenant_id, typ) - - mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant_id, typ)) - assert result == mock_providers - - def test_get_cached_providers_decode_error(self, 
mock_redis_client): - """Test get cached providers - cache hit but decoding failed""" - tenant_id = "tenant_123" - mock_redis_client.get.return_value = b"invalid_json_data" - - result = ToolProviderListCache.get_cached_providers(tenant_id) - - assert result is None - mock_redis_client.get.assert_called_once() - - def test_get_cached_providers_miss(self, mock_redis_client): - """Test get cached providers - cache miss""" - tenant_id = "tenant_123" - mock_redis_client.get.return_value = None - - result = ToolProviderListCache.get_cached_providers(tenant_id) - - assert result is None - mock_redis_client.get.assert_called_once() - - def test_set_cached_providers(self, mock_redis_client): - """Test set cached providers""" - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "builtin" - mock_providers = [{"id": "tool", "name": "test_provider"}] - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - - ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers) - - mock_redis_client.setex.assert_called_once_with( - cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers) - ) - - def test_invalidate_cache_specific_type(self, mock_redis_client): - """Test invalidate cache - specific type""" - tenant_id = "tenant_123" - typ: ToolProviderTypeApiLiteral = "workflow" - cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) - - ToolProviderListCache.invalidate_cache(tenant_id, typ) - - mock_redis_client.delete.assert_called_once_with(cache_key) - - def test_invalidate_cache_all_types(self, mock_redis_client): - """Test invalidate cache - clear all tenant cache""" - tenant_id = "tenant_123" - mock_keys = [ - b"tool_providers:tenant_id:tenant_123:type:all", - b"tool_providers:tenant_id:tenant_123:type:builtin", - ] - mock_redis_client.scan_iter.return_value = mock_keys - - ToolProviderListCache.invalidate_cache(tenant_id) - - def test_invalidate_cache_no_keys(self, mock_redis_client): - """Test invalidate cache - no cache keys for tenant""" - tenant_id = "tenant_123" - mock_redis_client.scan_iter.return_value = [] - - ToolProviderListCache.invalidate_cache(tenant_id) - - mock_redis_client.delete.assert_not_called() - - def test_redis_fallback_default_return(self, mock_redis_client): - """Test redis_fallback decorator - default return value (Redis error)""" - mock_redis_client.get.side_effect = RedisError("Redis connection error") - - result = ToolProviderListCache.get_cached_providers("tenant_123") - - assert result is None - mock_redis_client.get.assert_called_once() - - def test_redis_fallback_no_default(self, mock_redis_client): - """Test redis_fallback decorator - no default return value (Redis error)""" - mock_redis_client.setex.side_effect = RedisError("Redis connection error") - - try: - ToolProviderListCache.set_cached_providers("tenant_123", "mcp", []) - except RedisError: - pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)") - - mock_redis_client.setex.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/cleaner/__init__.py b/api/tests/unit_tests/core/rag/cleaner/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py new file mode 100644 index 0000000000..65ee62b8dd --- /dev/null +++ b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py @@ -0,0 +1,213 @@ +from core.rag.cleaner.clean_processor import CleanProcessor + + +class 
TestCleanProcessor: + """Test cases for CleanProcessor.clean method.""" + + def test_clean_default_removal_of_invalid_symbols(self): + """Test default cleaning removes invalid symbols.""" + # Test <| replacement + assert CleanProcessor.clean("text<|with<|invalid", None) == "text<with<invalid" + + # Test |> replacement + assert CleanProcessor.clean("text|>with|>invalid", None) == "text>with>invalid" + + # Test removal of control characters + text_with_control = "normal\x00text\x1fwith\x07control\x7fchars" + expected = "normaltextwithcontrolchars" + assert CleanProcessor.clean(text_with_control, None) == expected + + # Test U+FFFE removal + text_with_ufffe = "normal\ufffepadding" + expected = "normalpadding" + assert CleanProcessor.clean(text_with_ufffe, None) == expected + + def test_clean_with_none_process_rule(self): + """Test cleaning with None process_rule - only default cleaning applied.""" + text = "Hello<|World\x00" + expected = "Hello becomes >, control chars and U+FFFE are removed + assert CleanProcessor.clean(text, None) == "<<>>" + + def test_clean_multiple_markdown_links_preserved(self): + """Test multiple markdown links are all preserved.""" + process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}} + + text = "[One](https://one.com) [Two](http://two.org) [Three](https://three.net)" + expected = "[One](https://one.com) [Two](http://two.org) [Three](https://three.net)" + assert CleanProcessor.clean(text, process_rule) == expected + + def test_clean_markdown_link_text_as_url(self): + """Test markdown link where the link text itself is a URL.""" + process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}} + + # Link text that looks like URL should be preserved + text = "[https://text-url.com](https://actual-url.com)" + expected = "[https://text-url.com](https://actual-url.com)" + assert CleanProcessor.clean(text, process_rule) == expected + + # Text URL without markdown should be removed + text = "https://text-url.com https://actual-url.com" + expected = " " + assert CleanProcessor.clean(text, process_rule) == expected + + def test_clean_complex_markdown_link_content(self): + """Test markdown links with complex content - known limitation with brackets in link text.""" + process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}} + + # Note: The regex pattern [^\]]* cannot handle ] within link text + # This is a known limitation - the pattern stops at the first ] + text = "[Text with [brackets] and (parens)](https://example.com)" + # Actual behavior: only matches up to first ], URL gets removed + expected = "[Text with [brackets] and (parens)](" + assert CleanProcessor.clean(text, process_rule) == expected + + # Test that properly formatted markdown links work + text = "[Text with (parens) and symbols](https://example.com)" + expected = "[Text with (parens) and symbols](https://example.com)" + assert CleanProcessor.clean(text, process_rule) == expected diff --git a/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py new file mode 100644 index 0000000000..3167a9a301 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py @@ -0,0 +1,186 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.extractor.pdf_extractor as pe + + +@pytest.fixture +def mock_dependencies(monkeypatch): + # Mock storage + saves = [] + + def save(key,
data): + saves.append((key, data)) + + monkeypatch.setattr(pe, "storage", SimpleNamespace(save=save)) + + # Mock db + class DummySession: + def __init__(self): + self.added = [] + self.committed = False + + def add(self, obj): + self.added.append(obj) + + def add_all(self, objs): + self.added.extend(objs) + + def commit(self): + self.committed = True + + db_stub = SimpleNamespace(session=DummySession()) + monkeypatch.setattr(pe, "db", db_stub) + + # Mock UploadFile + class FakeUploadFile: + DEFAULT_ID = "test_file_id" + + def __init__(self, **kwargs): + # Assign id from DEFAULT_ID, allow override via kwargs if needed + self.id = self.DEFAULT_ID + for k, v in kwargs.items(): + setattr(self, k, v) + + monkeypatch.setattr(pe, "UploadFile", FakeUploadFile) + + # Mock config + monkeypatch.setattr(pe.dify_config, "FILES_URL", "http://files.local") + monkeypatch.setattr(pe.dify_config, "INTERNAL_FILES_URL", None) + monkeypatch.setattr(pe.dify_config, "STORAGE_TYPE", "local") + + return SimpleNamespace(saves=saves, db=db_stub, UploadFile=FakeUploadFile) + + +@pytest.mark.parametrize( + ("image_bytes", "expected_mime", "expected_ext", "file_id"), + [ + (b"\xff\xd8\xff some jpeg", "image/jpeg", "jpg", "test_file_id_jpeg"), + (b"\x89PNG\r\n\x1a\n some png", "image/png", "png", "test_file_id_png"), + ], +) +def test_extract_images_formats(mock_dependencies, monkeypatch, image_bytes, expected_mime, expected_ext, file_id): + saves = mock_dependencies.saves + db_stub = mock_dependencies.db + + # Customize FakeUploadFile id for this test case. + # Using monkeypatch ensures the class attribute is reset between parameter sets. + monkeypatch.setattr(mock_dependencies.UploadFile, "DEFAULT_ID", file_id) + + # Mock page and image objects + mock_page = MagicMock() + mock_image_obj = MagicMock() + + def mock_extract(buf, fb_format=None): + buf.write(image_bytes) + + mock_image_obj.extract.side_effect = mock_extract + + mock_page.get_objects.return_value = [mock_image_obj] + + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + # We need to handle the import inside _extract_images + with patch("pypdfium2.raw") as mock_raw: + mock_raw.FPDF_PAGEOBJ_IMAGE = 1 + result = extractor._extract_images(mock_page) + + assert f"![image](http://files.local/files/{file_id}/file-preview)" in result + assert len(saves) == 1 + assert saves[0][1] == image_bytes + assert len(db_stub.session.added) == 1 + assert db_stub.session.added[0].tenant_id == "t1" + assert db_stub.session.added[0].size == len(image_bytes) + assert db_stub.session.added[0].mime_type == expected_mime + assert db_stub.session.added[0].extension == expected_ext + assert db_stub.session.committed is True + + +@pytest.mark.parametrize( + ("get_objects_side_effect", "get_objects_return_value"), + [ + (None, []), # Empty list + (None, None), # None returned + (Exception("Failed to get objects"), None), # Exception raised + ], +) +def test_extract_images_get_objects_scenarios(mock_dependencies, get_objects_side_effect, get_objects_return_value): + mock_page = MagicMock() + if get_objects_side_effect: + mock_page.get_objects.side_effect = get_objects_side_effect + else: + mock_page.get_objects.return_value = get_objects_return_value + + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + with patch("pypdfium2.raw") as mock_raw: + mock_raw.FPDF_PAGEOBJ_IMAGE = 1 + result = extractor._extract_images(mock_page) + + assert result == "" + + +def test_extract_calls_extract_images(mock_dependencies, monkeypatch): 
+ # Mock pypdfium2 + mock_pdf_doc = MagicMock() + mock_page = MagicMock() + mock_pdf_doc.__iter__.return_value = [mock_page] + + # Mock text extraction + mock_text_page = MagicMock() + mock_text_page.get_text_range.return_value = "Page text content" + mock_page.get_textpage.return_value = mock_text_page + + with patch("pypdfium2.PdfDocument", return_value=mock_pdf_doc): + # Mock Blob + mock_blob = MagicMock() + mock_blob.source = "test.pdf" + with patch("core.rag.extractor.pdf_extractor.Blob.from_path", return_value=mock_blob): + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + # Mock _extract_images to return a known string + monkeypatch.setattr(extractor, "_extract_images", lambda p: "![image](img_url)") + + documents = list(extractor.extract()) + + assert len(documents) == 1 + assert "Page text content" in documents[0].page_content + assert "![image](img_url)" in documents[0].page_content + assert documents[0].metadata["page"] == 0 + + +def test_extract_images_failures(mock_dependencies): + saves = mock_dependencies.saves + db_stub = mock_dependencies.db + + # Mock page and image objects + mock_page = MagicMock() + mock_image_obj_fail = MagicMock() + mock_image_obj_ok = MagicMock() + + # First image raises exception + mock_image_obj_fail.extract.side_effect = Exception("Extraction failure") + + # Second image is OK (JPEG) + jpeg_bytes = b"\xff\xd8\xff some image data" + + def mock_extract(buf, fb_format=None): + buf.write(jpeg_bytes) + + mock_image_obj_ok.extract.side_effect = mock_extract + + mock_page.get_objects.return_value = [mock_image_obj_fail, mock_image_obj_ok] + + extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") + + with patch("pypdfium2.raw") as mock_raw: + mock_raw.FPDF_PAGEOBJ_IMAGE = 1 + result = extractor._extract_images(mock_page) + + # Should have one success + assert "![image](http://files.local/files/test_file_id/file-preview)" in result + assert len(saves) == 1 + assert saves[0][1] == jpeg_bytes + assert db_stub.session.committed is True diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index affd6c648f..6306d665e7 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -421,7 +421,18 @@ class TestRetrievalService: # In real code, this waits for all futures to complete # In tests, futures complete immediately, so wait is a no-op with patch("core.rag.datasource.retrieval_service.concurrent.futures.wait"): - yield mock_executor + # Mock concurrent.futures.as_completed for early error propagation + # In real code, this yields futures as they complete + # In tests, we yield all futures immediately since they're already done + def mock_as_completed(futures_list, timeout=None): + """Mock as_completed that yields futures immediately.""" + yield from futures_list + + with patch( + "core.rag.datasource.retrieval_service.concurrent.futures.as_completed", + side_effect=mock_as_completed, + ): + yield mock_executor # ==================== Vector Search Tests ==================== diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 9060cf7b6c..636fac7a40 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -32,7 +32,6 @@ def mock_provider_entity(): 
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI ęä¾›å•†"), icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"), - icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"), background="background.png", help=None, supported_model_types=[ModelType.LLM], diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py new file mode 100644 index 0000000000..cf8811dc2b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py @@ -0,0 +1 @@ +"""Tests for graph traversal components.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py new file mode 100644 index 0000000000..0019020ede --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py @@ -0,0 +1,307 @@ +"""Unit tests for skip propagator.""" + +from unittest.mock import MagicMock, create_autospec + +from core.workflow.graph import Edge, Graph +from core.workflow.graph_engine.graph_state_manager import GraphStateManager +from core.workflow.graph_engine.graph_traversal.skip_propagator import SkipPropagator + + +class TestSkipPropagator: + """Test suite for SkipPropagator.""" + + def test_propagate_skip_from_edge_with_unknown_edges_stops_processing(self) -> None: + """When there are unknown incoming edges, propagation should stop.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create a mock edge + mock_edge = MagicMock(spec=Edge) + mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + # Setup graph edges dict + mock_graph.edges = {"edge_1": mock_edge} + + # Setup incoming edges + incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge)] + mock_graph.get_incoming_edges.return_value = incoming_edges + + # Setup state manager to return has_unknown=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": True, + "has_taken": False, + "all_skipped": False, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert + mock_graph.get_incoming_edges.assert_called_once_with("node_2") + mock_state_manager.analyze_edge_states.assert_called_once_with(incoming_edges) + # Should not call any other state manager methods + mock_state_manager.enqueue_node.assert_not_called() + mock_state_manager.start_execution.assert_not_called() + mock_state_manager.mark_node_skipped.assert_not_called() + + def test_propagate_skip_from_edge_with_taken_edge_enqueues_node(self) -> None: + """When there is at least one taken edge, node should be enqueued.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create a mock edge + mock_edge = MagicMock(spec=Edge) + mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + mock_graph.edges = {"edge_1": mock_edge} + incoming_edges = [MagicMock(spec=Edge)] + mock_graph.get_incoming_edges.return_value = incoming_edges + + # Setup state manager to return has_taken=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": True, + "all_skipped": False, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + 
propagator.propagate_skip_from_edge("edge_1") + + # Assert + mock_state_manager.enqueue_node.assert_called_once_with("node_2") + mock_state_manager.start_execution.assert_called_once_with("node_2") + mock_state_manager.mark_node_skipped.assert_not_called() + + def test_propagate_skip_from_edge_with_all_skipped_propagates_to_node(self) -> None: + """When all incoming edges are skipped, should propagate skip to node.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create a mock edge + mock_edge = MagicMock(spec=Edge) + mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + mock_graph.edges = {"edge_1": mock_edge} + incoming_edges = [MagicMock(spec=Edge)] + mock_graph.get_incoming_edges.return_value = incoming_edges + + # Setup state manager to return all_skipped=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": False, + "all_skipped": True, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert + mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") + mock_state_manager.enqueue_node.assert_not_called() + mock_state_manager.start_execution.assert_not_called() + + def test_propagate_skip_to_node_marks_node_and_outgoing_edges_skipped(self) -> None: + """_propagate_skip_to_node should mark node and all outgoing edges as skipped.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create outgoing edges + edge1 = MagicMock(spec=Edge) + edge1.id = "edge_2" + edge1.head = "node_downstream_1" # Set head for propagate_skip_from_edge + + edge2 = MagicMock(spec=Edge) + edge2.id = "edge_3" + edge2.head = "node_downstream_2" + + # Setup graph edges dict for propagate_skip_from_edge + mock_graph.edges = {"edge_2": edge1, "edge_3": edge2} + mock_graph.get_outgoing_edges.return_value = [edge1, edge2] + + # Setup get_incoming_edges to return empty list to stop recursion + mock_graph.get_incoming_edges.return_value = [] + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Use mock to call private method + # Act + propagator._propagate_skip_to_node("node_1") + + # Assert + mock_state_manager.mark_node_skipped.assert_called_once_with("node_1") + mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") + mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") + assert mock_state_manager.mark_edge_skipped.call_count == 2 + # Should recursively propagate from each edge + # Since propagate_skip_from_edge is called, we need to verify it was called + # But we can't directly verify due to recursion. We'll trust the logic. 
+ + def test_skip_branch_paths_marks_unselected_edges_and_propagates(self) -> None: + """skip_branch_paths should mark all unselected edges as skipped and propagate.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create unselected edges + edge1 = MagicMock(spec=Edge) + edge1.id = "edge_1" + edge1.head = "node_downstream_1" + + edge2 = MagicMock(spec=Edge) + edge2.id = "edge_2" + edge2.head = "node_downstream_2" + + unselected_edges = [edge1, edge2] + + # Setup graph edges dict + mock_graph.edges = {"edge_1": edge1, "edge_2": edge2} + # Setup get_incoming_edges to return empty list to stop recursion + mock_graph.get_incoming_edges.return_value = [] + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.skip_branch_paths(unselected_edges) + + # Assert + mock_state_manager.mark_edge_skipped.assert_any_call("edge_1") + mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") + assert mock_state_manager.mark_edge_skipped.call_count == 2 + # propagate_skip_from_edge should be called for each edge + # We can't directly verify due to the mock, but the logic is covered + + def test_propagate_skip_from_edge_recursively_propagates_through_graph(self) -> None: + """Skip propagation should recursively propagate through the graph.""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + # Create edge chain: edge_1 -> node_2 -> edge_3 -> node_4 + edge1 = MagicMock(spec=Edge) + edge1.id = "edge_1" + edge1.head = "node_2" + + edge3 = MagicMock(spec=Edge) + edge3.id = "edge_3" + edge3.head = "node_4" + + mock_graph.edges = {"edge_1": edge1, "edge_3": edge3} + + # Setup get_incoming_edges to return different values based on node + def get_incoming_edges_side_effect(node_id): + if node_id == "node_2": + return [edge1] + elif node_id == "node_4": + return [edge3] + return [] + + mock_graph.get_incoming_edges.side_effect = get_incoming_edges_side_effect + + # Setup get_outgoing_edges to return different values based on node + def get_outgoing_edges_side_effect(node_id): + if node_id == "node_2": + return [edge3] + elif node_id == "node_4": + return [] # No outgoing edges, stops recursion + return [] + + mock_graph.get_outgoing_edges.side_effect = get_outgoing_edges_side_effect + + # Setup state manager to return all_skipped for both nodes + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": False, + "all_skipped": True, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert + # Should mark node_2 as skipped + mock_state_manager.mark_node_skipped.assert_any_call("node_2") + # Should mark edge_3 as skipped + mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") + # Should propagate to node_4 + mock_state_manager.mark_node_skipped.assert_any_call("node_4") + assert mock_state_manager.mark_node_skipped.call_count == 2 + + def test_propagate_skip_from_edge_with_mixed_edge_states_handles_correctly(self) -> None: + """Test with mixed edge states (some unknown, some taken, some skipped).""" + # Arrange + mock_graph = create_autospec(Graph) + mock_state_manager = create_autospec(GraphStateManager) + + mock_edge = MagicMock(spec=Edge) + mock_edge.id = "edge_1" + mock_edge.head = "node_2" + + mock_graph.edges = {"edge_1": mock_edge} + incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge), MagicMock(spec=Edge)] + 
mock_graph.get_incoming_edges.return_value = incoming_edges + + # Test 1: has_unknown=True, has_taken=False, all_skipped=False + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": True, + "has_taken": False, + "all_skipped": False, + } + + propagator = SkipPropagator(mock_graph, mock_state_manager) + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert - should stop processing + mock_state_manager.enqueue_node.assert_not_called() + mock_state_manager.mark_node_skipped.assert_not_called() + + # Reset mocks for next test + mock_state_manager.reset_mock() + mock_graph.reset_mock() + + # Test 2: has_unknown=False, has_taken=True, all_skipped=False + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": True, + "all_skipped": False, + } + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert - should enqueue node + mock_state_manager.enqueue_node.assert_called_once_with("node_2") + mock_state_manager.start_execution.assert_called_once_with("node_2") + + # Reset mocks for next test + mock_state_manager.reset_mock() + mock_graph.reset_mock() + + # Test 3: has_unknown=False, has_taken=False, all_skipped=True + mock_state_manager.analyze_edge_states.return_value = { + "has_unknown": False, + "has_taken": False, + "all_skipped": True, + } + + # Act + propagator.propagate_skip_from_edge("edge_1") + + # Assert - should propagate skip + mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py index fc7a090ef9..d3a4d69f07 100644 --- a/api/tests/unit_tests/extensions/test_celery_ssl.py +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -8,11 +8,12 @@ class TestCelerySSLConfiguration: """Test suite for Celery SSL configuration.""" def test_get_celery_ssl_options_when_ssl_disabled(self): - """Test SSL options when REDIS_USE_SSL is False.""" - mock_config = MagicMock() - mock_config.REDIS_USE_SSL = False + """Test SSL options when BROKER_USE_SSL is False.""" + from configs import DifyConfig - with patch("extensions.ext_celery.dify_config", mock_config): + dify_config = DifyConfig(CELERY_BROKER_URL="redis://localhost:6379/0") + + with patch("extensions.ext_celery.dify_config", dify_config): from extensions.ext_celery import _get_celery_ssl_options result = _get_celery_ssl_options() @@ -21,7 +22,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_when_broker_not_redis(self): """Test SSL options when broker is not Redis.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "amqp://localhost:5672" with patch("extensions.ext_celery.dify_config", mock_config): @@ -33,7 +33,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_cert_none(self): """Test SSL options with CERT_NONE requirement.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE" mock_config.REDIS_SSL_CA_CERTS = None @@ -53,7 +52,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_cert_required(self): """Test SSL options with CERT_REQUIRED and certificates.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "rediss://localhost:6380/0" mock_config.REDIS_SSL_CERT_REQS = "CERT_REQUIRED" mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" @@ -73,7 +71,6 @@ class 
TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_cert_optional(self): """Test SSL options with CERT_OPTIONAL requirement.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.REDIS_SSL_CERT_REQS = "CERT_OPTIONAL" mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" @@ -91,7 +88,6 @@ class TestCelerySSLConfiguration: def test_get_celery_ssl_options_with_invalid_cert_reqs(self): """Test SSL options with invalid cert requirement defaults to CERT_NONE.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.REDIS_SSL_CERT_REQS = "INVALID_VALUE" mock_config.REDIS_SSL_CA_CERTS = None @@ -108,7 +104,6 @@ class TestCelerySSLConfiguration: def test_celery_init_applies_ssl_to_broker_and_backend(self): """Test that SSL options are applied to both broker and backend when using Redis.""" mock_config = MagicMock() - mock_config.REDIS_USE_SSL = True mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" mock_config.CELERY_BACKEND = "redis" mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" diff --git a/api/tests/unit_tests/libs/test_archive_storage.py b/api/tests/unit_tests/libs/test_archive_storage.py new file mode 100644 index 0000000000..697760e33a --- /dev/null +++ b/api/tests/unit_tests/libs/test_archive_storage.py @@ -0,0 +1,272 @@ +import base64 +import hashlib +from datetime import datetime +from unittest.mock import ANY, MagicMock + +import pytest +from botocore.exceptions import ClientError + +from libs import archive_storage as storage_module +from libs.archive_storage import ( + ArchiveStorage, + ArchiveStorageError, + ArchiveStorageNotConfiguredError, +) + +BUCKET_NAME = "archive-bucket" + + +def _configure_storage(monkeypatch, **overrides): + defaults = { + "ARCHIVE_STORAGE_ENABLED": True, + "ARCHIVE_STORAGE_ENDPOINT": "https://storage.example.com", + "ARCHIVE_STORAGE_ARCHIVE_BUCKET": BUCKET_NAME, + "ARCHIVE_STORAGE_ACCESS_KEY": "access", + "ARCHIVE_STORAGE_SECRET_KEY": "secret", + "ARCHIVE_STORAGE_REGION": "auto", + } + defaults.update(overrides) + for key, value in defaults.items(): + monkeypatch.setattr(storage_module.dify_config, key, value, raising=False) + + +def _client_error(code: str) -> ClientError: + return ClientError({"Error": {"Code": code}}, "Operation") + + +def _mock_client(monkeypatch): + client = MagicMock() + client.head_bucket.return_value = None + boto_client = MagicMock(return_value=client) + monkeypatch.setattr(storage_module.boto3, "client", boto_client) + return client, boto_client + + +def test_init_disabled(monkeypatch): + _configure_storage(monkeypatch, ARCHIVE_STORAGE_ENABLED=False) + with pytest.raises(ArchiveStorageNotConfiguredError, match="not enabled"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_missing_config(monkeypatch): + _configure_storage(monkeypatch, ARCHIVE_STORAGE_ENDPOINT=None) + with pytest.raises(ArchiveStorageNotConfiguredError, match="incomplete"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_bucket_not_found(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.head_bucket.side_effect = _client_error("404") + + with pytest.raises(ArchiveStorageNotConfiguredError, match="does not exist"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_bucket_access_denied(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.head_bucket.side_effect = 
_client_error("403") + + with pytest.raises(ArchiveStorageNotConfiguredError, match="Access denied"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_bucket_other_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.head_bucket.side_effect = _client_error("500") + + with pytest.raises(ArchiveStorageError, match="Failed to access archive bucket"): + ArchiveStorage(bucket=BUCKET_NAME) + + +def test_init_sets_client(monkeypatch): + _configure_storage(monkeypatch) + client, boto_client = _mock_client(monkeypatch) + + storage = ArchiveStorage(bucket=BUCKET_NAME) + + boto_client.assert_called_once_with( + "s3", + endpoint_url="https://storage.example.com", + aws_access_key_id="access", + aws_secret_access_key="secret", + region_name="auto", + config=ANY, + ) + assert storage.client is client + assert storage.bucket == BUCKET_NAME + + +def test_put_object_returns_checksum(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + storage = ArchiveStorage(bucket=BUCKET_NAME) + + data = b"hello" + checksum = storage.put_object("key", data) + + expected_md5 = hashlib.md5(data).hexdigest() + expected_content_md5 = base64.b64encode(hashlib.md5(data).digest()).decode() + client.put_object.assert_called_once_with( + Bucket="archive-bucket", + Key="key", + Body=data, + ContentMD5=expected_content_md5, + ) + assert checksum == expected_md5 + + +def test_put_object_raises_on_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + storage = ArchiveStorage(bucket=BUCKET_NAME) + client.put_object.side_effect = _client_error("500") + + with pytest.raises(ArchiveStorageError, match="Failed to upload object"): + storage.put_object("key", b"data") + + +def test_get_object_returns_bytes(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + body = MagicMock() + body.read.return_value = b"payload" + client.get_object.return_value = {"Body": body} + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert storage.get_object("key") == b"payload" + + +def test_get_object_missing(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.get_object.side_effect = _client_error("NoSuchKey") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(FileNotFoundError, match="Archive object not found"): + storage.get_object("missing") + + +def test_get_object_stream(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + body = MagicMock() + body.iter_chunks.return_value = [b"a", b"b"] + client.get_object.return_value = {"Body": body} + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert list(storage.get_object_stream("key")) == [b"a", b"b"] + + +def test_get_object_stream_missing(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.get_object.side_effect = _client_error("NoSuchKey") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(FileNotFoundError, match="Archive object not found"): + list(storage.get_object_stream("missing")) + + +def test_object_exists(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert storage.object_exists("key") is True + client.head_object.side_effect = _client_error("404") + assert storage.object_exists("missing") is False + + +def test_delete_object_error(monkeypatch): + 
_configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.delete_object.side_effect = _client_error("500") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(ArchiveStorageError, match="Failed to delete object"): + storage.delete_object("key") + + +def test_list_objects(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + paginator = MagicMock() + paginator.paginate.return_value = [ + {"Contents": [{"Key": "a"}, {"Key": "b"}]}, + {"Contents": [{"Key": "c"}]}, + ] + client.get_paginator.return_value = paginator + storage = ArchiveStorage(bucket=BUCKET_NAME) + + assert storage.list_objects("prefix") == ["a", "b", "c"] + paginator.paginate.assert_called_once_with(Bucket="archive-bucket", Prefix="prefix") + + +def test_list_objects_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + paginator = MagicMock() + paginator.paginate.side_effect = _client_error("500") + client.get_paginator.return_value = paginator + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(ArchiveStorageError, match="Failed to list objects"): + storage.list_objects("prefix") + + +def test_generate_presigned_url(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.generate_presigned_url.return_value = "http://signed-url" + storage = ArchiveStorage(bucket=BUCKET_NAME) + + url = storage.generate_presigned_url("key", expires_in=123) + + client.generate_presigned_url.assert_called_once_with( + ClientMethod="get_object", + Params={"Bucket": "archive-bucket", "Key": "key"}, + ExpiresIn=123, + ) + assert url == "http://signed-url" + + +def test_generate_presigned_url_error(monkeypatch): + _configure_storage(monkeypatch) + client, _ = _mock_client(monkeypatch) + client.generate_presigned_url.side_effect = _client_error("500") + storage = ArchiveStorage(bucket=BUCKET_NAME) + + with pytest.raises(ArchiveStorageError, match="Failed to generate pre-signed URL"): + storage.generate_presigned_url("key") + + +def test_serialization_roundtrip(): + records = [ + { + "id": "1", + "created_at": datetime(2024, 1, 1, 12, 0, 0), + "payload": {"nested": "value"}, + "items": [{"name": "a"}], + }, + {"id": "2", "value": 123}, + ] + + data = ArchiveStorage.serialize_to_jsonl_gz(records) + decoded = ArchiveStorage.deserialize_from_jsonl_gz(data) + + assert decoded[0]["id"] == "1" + assert decoded[0]["payload"]["nested"] == "value" + assert decoded[0]["items"][0]["name"] == "a" + assert "2024-01-01T12:00:00" in decoded[0]["created_at"] + assert decoded[1]["value"] == 123 + + +def test_content_md5_matches_checksum(): + data = b"checksum" + expected = base64.b64encode(hashlib.md5(data).digest()).decode() + + assert ArchiveStorage._content_md5(data) == expected + assert ArchiveStorage.compute_checksum(data) == hashlib.md5(data).hexdigest() diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py index 303f0493bd..a0fed1aa14 100644 --- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from qcloud_cos import CosConfig @@ -18,3 +18,72 @@ class TestTencentCos(BaseStorageTest): with patch.object(CosConfig, "__init__", return_value=None): self.storage = TencentCosStorage() self.storage.bucket_name = get_example_bucket() + + +class 
TestTencentCosConfiguration: + """Tests for TencentCosStorage initialization with different configurations.""" + + def test_init_with_custom_domain(self): + """Test initialization with custom domain configured.""" + # Mock dify_config to return custom domain configuration + mock_dify_config = MagicMock() + mock_dify_config.TENCENT_COS_CUSTOM_DOMAIN = "cos.example.com" + mock_dify_config.TENCENT_COS_SECRET_ID = "test-secret-id" + mock_dify_config.TENCENT_COS_SECRET_KEY = "test-secret-key" + mock_dify_config.TENCENT_COS_SCHEME = "https" + + # Mock CosConfig and CosS3Client + mock_config_instance = MagicMock() + mock_client = MagicMock() + + with ( + patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config), + patch( + "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance + ) as mock_cos_config, + patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client), + ): + TencentCosStorage() + + # Verify CosConfig was called with Domain parameter (not Region) + mock_cos_config.assert_called_once() + call_kwargs = mock_cos_config.call_args[1] + assert "Domain" in call_kwargs + assert call_kwargs["Domain"] == "cos.example.com" + assert "Region" not in call_kwargs + assert call_kwargs["SecretId"] == "test-secret-id" + assert call_kwargs["SecretKey"] == "test-secret-key" + assert call_kwargs["Scheme"] == "https" + + def test_init_with_region(self): + """Test initialization with region configured (no custom domain).""" + # Mock dify_config to return region configuration + mock_dify_config = MagicMock() + mock_dify_config.TENCENT_COS_CUSTOM_DOMAIN = None + mock_dify_config.TENCENT_COS_REGION = "ap-guangzhou" + mock_dify_config.TENCENT_COS_SECRET_ID = "test-secret-id" + mock_dify_config.TENCENT_COS_SECRET_KEY = "test-secret-key" + mock_dify_config.TENCENT_COS_SCHEME = "https" + + # Mock CosConfig and CosS3Client + mock_config_instance = MagicMock() + mock_client = MagicMock() + + with ( + patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config), + patch( + "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance + ) as mock_cos_config, + patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client), + ): + TencentCosStorage() + + # Verify CosConfig was called with Region parameter (not Domain) + mock_cos_config.assert_called_once() + call_kwargs = mock_cos_config.call_args[1] + assert "Region" in call_kwargs + assert call_kwargs["Region"] == "ap-guangzhou" + assert "Domain" not in call_kwargs + assert call_kwargs["SecretId"] == "test-secret-id" + assert call_kwargs["SecretKey"] == "test-secret-key" + assert call_kwargs["Scheme"] == "https" diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index f50f744a75..d00743278e 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1294,6 +1294,42 @@ class TestBillingServiceSubscriptionOperations: # Assert assert result == {} + def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request): + """Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant).""" + # Arrange + tenant_ids = ["tenant-valid-1", "tenant-invalid", "tenant-valid-2"] + + # Response with one invalid tenant plan (missing expiration_date) and two valid ones + mock_send_request.return_value = { + "data": { + "tenant-valid-1": {"plan": "sandbox", 
"expiration_date": 1735689600}, + "tenant-invalid": {"plan": "professional"}, # Missing expiration_date field + "tenant-valid-2": {"plan": "team", "expiration_date": 1767225600}, + } + } + + # Act + with patch("services.billing_service.logger") as mock_logger: + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert - should only contain valid tenants + assert len(result) == 2 + assert "tenant-valid-1" in result + assert "tenant-valid-2" in result + assert "tenant-invalid" not in result + + # Verify valid tenants have correct data + assert result["tenant-valid-1"]["plan"] == "sandbox" + assert result["tenant-valid-1"]["expiration_date"] == 1735689600 + assert result["tenant-valid-2"]["plan"] == "team" + assert result["tenant-valid-2"]["expiration_date"] == 1767225600 + + # Verify exception was logged for the invalid tenant + mock_logger.exception.assert_called_once() + log_call_args = mock_logger.exception.call_args[0] + assert "get_plan_bulk: failed to validate subscription plan for tenant" in log_call_args[0] + assert "tenant-invalid" in log_call_args[1] + def test_get_expired_subscription_cleanup_whitelist_success(self, mock_send_request): """Test successful retrieval of expired subscription cleanup whitelist.""" # Arrange diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 9a107da1c7..e2360b116d 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -27,7 +27,6 @@ def service_with_fake_configurations(): description=None, icon_small=None, icon_small_dark=None, - icon_large=None, background=None, help=None, supported_model_types=[ModelType.LLM], diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py index 11e017464a..bf61162a66 100644 --- a/api/tests/unit_tests/utils/test_text_processing.py +++ b/api/tests/unit_tests/utils/test_text_processing.py @@ -15,6 +15,11 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols ("", ""), (" ", " "), ("【测试】", "【测试】"), + # Markdown link preservation - should be preserved if text starts with a markdown link + ("[Google](https://google.com) is a search engine", "[Google](https://google.com) is a search engine"), + ("[Example](http://example.com) some text", "[Example](http://example.com) some text"), + # Leading symbols before markdown link are removed, including the opening bracket [ + ("@[Test](https://example.com)", "Test](https://example.com)"), ], ) def test_remove_leading_symbols(input_text, expected_output): diff --git a/docker/.env.example b/docker/.env.example index 1ea1fb9a8e..5c1089408c 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -447,6 +447,15 @@ S3_SECRET_KEY= # If set to false, the access key and secret key must be provided. 
S3_USE_AWS_MANAGED_IAM=false +# Workflow run and Conversation archive storage (S3-compatible) +ARCHIVE_STORAGE_ENABLED=false +ARCHIVE_STORAGE_ENDPOINT= +ARCHIVE_STORAGE_ARCHIVE_BUCKET= +ARCHIVE_STORAGE_EXPORT_BUCKET= +ARCHIVE_STORAGE_ACCESS_KEY= +ARCHIVE_STORAGE_SECRET_KEY= +ARCHIVE_STORAGE_REGION=auto + # Azure Blob Configuration # AZURE_BLOB_ACCOUNT_NAME=difyai @@ -478,6 +487,7 @@ TENCENT_COS_SECRET_KEY=your-secret-key TENCENT_COS_SECRET_ID=your-secret-id TENCENT_COS_REGION=your-region TENCENT_COS_SCHEME=your-scheme +TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain # Oracle Storage Configuration # diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index c03cb2ef9f..9910c95a84 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -122,6 +122,13 @@ x-shared-env: &shared-api-worker-env S3_ACCESS_KEY: ${S3_ACCESS_KEY:-} S3_SECRET_KEY: ${S3_SECRET_KEY:-} S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false} + ARCHIVE_STORAGE_ENABLED: ${ARCHIVE_STORAGE_ENABLED:-false} + ARCHIVE_STORAGE_ENDPOINT: ${ARCHIVE_STORAGE_ENDPOINT:-} + ARCHIVE_STORAGE_ARCHIVE_BUCKET: ${ARCHIVE_STORAGE_ARCHIVE_BUCKET:-} + ARCHIVE_STORAGE_EXPORT_BUCKET: ${ARCHIVE_STORAGE_EXPORT_BUCKET:-} + ARCHIVE_STORAGE_ACCESS_KEY: ${ARCHIVE_STORAGE_ACCESS_KEY:-} + ARCHIVE_STORAGE_SECRET_KEY: ${ARCHIVE_STORAGE_SECRET_KEY:-} + ARCHIVE_STORAGE_REGION: ${ARCHIVE_STORAGE_REGION:-auto} AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai} AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai} AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container} @@ -141,6 +148,7 @@ x-shared-env: &shared-api-worker-env TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} TENCENT_COS_REGION: ${TENCENT_COS_REGION:-your-region} TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-your-scheme} + TENCENT_COS_CUSTOM_DOMAIN: ${TENCENT_COS_CUSTOM_DOMAIN:-your-custom-domain} OCI_ENDPOINT: ${OCI_ENDPOINT:-https://your-object-storage-namespace.compat.objectstorage.us-ashburn-1.oraclecloud.com} OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-your-bucket-name} OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-your-access-key} diff --git a/web/__mocks__/provider-context.ts b/web/__mocks__/provider-context.ts index c69a2ad1d2..373c2f86d3 100644 --- a/web/__mocks__/provider-context.ts +++ b/web/__mocks__/provider-context.ts @@ -1,6 +1,7 @@ import type { Plan, UsagePlanInfo } from '@/app/components/billing/type' import type { ProviderContextState } from '@/context/provider-context' -import { merge, noop } from 'es-toolkit/compat' +import { merge } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import { defaultPlan } from '@/app/components/billing/config' // Avoid being mocked in tests diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts index a6f86d8107..9f573bda10 100644 --- a/web/__tests__/check-i18n.test.ts +++ b/web/__tests__/check-i18n.test.ts @@ -3,7 +3,7 @@ import path from 'node:path' import vm from 'node:vm' import { transpile } from 'typescript' -describe('check-i18n script functionality', () => { +describe('i18n:check script functionality', () => { const testDir = path.join(__dirname, '../i18n-test') const testEnDir = path.join(testDir, 'en-US') const testZhDir = path.join(testDir, 'zh-Hans') diff --git a/web/__tests__/i18n-upload-features.test.ts b/web/__tests__/i18n-upload-features.test.ts index af418262d6..3d0f1d30dd 100644 --- a/web/__tests__/i18n-upload-features.test.ts +++ b/web/__tests__/i18n-upload-features.test.ts @@ -16,7 +16,7 @@ const getSupportedLocales = (): 
string[] => { // Helper function to load translation file content const loadTranslationContent = (locale: string): string => { - const filePath = path.join(I18N_DIR, locale, 'app-debug.ts') + const filePath = path.join(I18N_DIR, locale, 'app-debug.json') if (!fs.existsSync(filePath)) throw new Error(`Translation file not found: ${filePath}`) @@ -24,14 +24,14 @@ const loadTranslationContent = (locale: string): string => { return fs.readFileSync(filePath, 'utf-8') } -// Helper function to check if upload features exist +// Helper function to check if upload features exist (supports flattened JSON) const hasUploadFeatures = (content: string): { [key: string]: boolean } => { return { - fileUpload: /fileUpload\s*:\s*\{/.test(content), - imageUpload: /imageUpload\s*:\s*\{/.test(content), - documentUpload: /documentUpload\s*:\s*\{/.test(content), - audioUpload: /audioUpload\s*:\s*\{/.test(content), - featureBar: /bar\s*:\s*\{/.test(content), + fileUpload: /"feature\.fileUpload\.title"/.test(content), + imageUpload: /"feature\.imageUpload\.title"/.test(content), + documentUpload: /"feature\.documentUpload\.title"/.test(content), + audioUpload: /"feature\.audioUpload\.title"/.test(content), + featureBar: /"feature\.bar\.empty"/.test(content), } } @@ -45,7 +45,7 @@ describe('Upload Features i18n Translations - Issue #23062', () => { it('all locales should have translation files', () => { supportedLocales.forEach((locale) => { - const filePath = path.join(I18N_DIR, locale, 'app-debug.ts') + const filePath = path.join(I18N_DIR, locale, 'app-debug.json') expect(fs.existsSync(filePath)).toBe(true) }) }) @@ -76,12 +76,9 @@ describe('Upload Features i18n Translations - Issue #23062', () => { previouslyMissingLocales.forEach((locale) => { const content = loadTranslationContent(locale) - // Verify audioUpload exists - expect(/audioUpload\s*:\s*\{/.test(content)).toBe(true) - - // Verify it has title and description - expect(/audioUpload[^}]*title\s*:/.test(content)).toBe(true) - expect(/audioUpload[^}]*description\s*:/.test(content)).toBe(true) + // Verify audioUpload exists with title and description (flattened JSON format) + expect(/"feature\.audioUpload\.title"/.test(content)).toBe(true) + expect(/"feature\.audioUpload\.description"/.test(content)).toBe(true) console.log(`āœ… ${locale} - Issue #23062 resolved: audioUpload feature present`) }) @@ -91,28 +88,28 @@ describe('Upload Features i18n Translations - Issue #23062', () => { supportedLocales.forEach((locale) => { const content = loadTranslationContent(locale) - // Check fileUpload has required properties - if (/fileUpload\s*:\s*\{/.test(content)) { - expect(/fileUpload[^}]*title\s*:/.test(content)).toBe(true) - expect(/fileUpload[^}]*description\s*:/.test(content)).toBe(true) + // Check fileUpload has required properties (flattened JSON format) + if (/"feature\.fileUpload\.title"/.test(content)) { + expect(/"feature\.fileUpload\.title"/.test(content)).toBe(true) + expect(/"feature\.fileUpload\.description"/.test(content)).toBe(true) } // Check imageUpload has required properties - if (/imageUpload\s*:\s*\{/.test(content)) { - expect(/imageUpload[^}]*title\s*:/.test(content)).toBe(true) - expect(/imageUpload[^}]*description\s*:/.test(content)).toBe(true) + if (/"feature\.imageUpload\.title"/.test(content)) { + expect(/"feature\.imageUpload\.title"/.test(content)).toBe(true) + expect(/"feature\.imageUpload\.description"/.test(content)).toBe(true) } // Check documentUpload has required properties - if (/documentUpload\s*:\s*\{/.test(content)) { - 
expect(/documentUpload[^}]*title\s*:/.test(content)).toBe(true) - expect(/documentUpload[^}]*description\s*:/.test(content)).toBe(true) + if (/"feature\.documentUpload\.title"/.test(content)) { + expect(/"feature\.documentUpload\.title"/.test(content)).toBe(true) + expect(/"feature\.documentUpload\.description"/.test(content)).toBe(true) } // Check audioUpload has required properties - if (/audioUpload\s*:\s*\{/.test(content)) { - expect(/audioUpload[^}]*title\s*:/.test(content)).toBe(true) - expect(/audioUpload[^}]*description\s*:/.test(content)).toBe(true) + if (/"feature\.audioUpload\.title"/.test(content)) { + expect(/"feature\.audioUpload\.title"/.test(content)).toBe(true) + expect(/"feature\.audioUpload\.description"/.test(content)).toBe(true) } }) }) diff --git a/web/__tests__/workflow-parallel-limit.test.tsx b/web/__tests__/workflow-parallel-limit.test.tsx index 18657f4bd2..ba3840ac3e 100644 --- a/web/__tests__/workflow-parallel-limit.test.tsx +++ b/web/__tests__/workflow-parallel-limit.test.tsx @@ -64,7 +64,6 @@ vi.mock('i18next', () => ({ // Mock the useConfig hook vi.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({ - __esModule: true, default: () => ({ inputs: { is_parallel: true, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index c1feb3ea5d..470f4477fa 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -70,7 +70,7 @@ const AppDetailLayout: FC = (props) => { const navConfig = [ ...(isCurrentWorkspaceEditor ? [{ - name: t('common.appMenus.promptEng'), + name: t('appMenus.promptEng', { ns: 'common' }), href: `/app/${appId}/${(mode === AppModeEnum.WORKFLOW || mode === AppModeEnum.ADVANCED_CHAT) ? 'workflow' : 'configuration'}`, icon: RiTerminalWindowLine, selectedIcon: RiTerminalWindowFill, @@ -78,7 +78,7 @@ const AppDetailLayout: FC = (props) => { : [] ), { - name: t('common.appMenus.apiAccess'), + name: t('appMenus.apiAccess', { ns: 'common' }), href: `/app/${appId}/develop`, icon: RiTerminalBoxLine, selectedIcon: RiTerminalBoxFill, @@ -86,8 +86,8 @@ const AppDetailLayout: FC = (props) => { ...(isCurrentWorkspaceEditor ? [{ name: mode !== AppModeEnum.WORKFLOW - ? t('common.appMenus.logAndAnn') - : t('common.appMenus.logs'), + ? 
t('appMenus.logAndAnn', { ns: 'common' }) + : t('appMenus.logs', { ns: 'common' }), href: `/app/${appId}/logs`, icon: RiFileList3Line, selectedIcon: RiFileList3Fill, @@ -95,7 +95,7 @@ const AppDetailLayout: FC = (props) => { : [] ), { - name: t('common.appMenus.overview'), + name: t('appMenus.overview', { ns: 'common' }), href: `/app/${appId}/overview`, icon: RiDashboard2Line, selectedIcon: RiDashboard2Fill, @@ -104,7 +104,7 @@ const AppDetailLayout: FC = (props) => { return navConfig }, [t]) - useDocumentTitle(appDetail?.name || t('common.menus.appDetail')) + useDocumentTitle(appDetail?.name || t('menus.appDetail', { ns: 'common' })) useEffect(() => { if (appDetail) { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx index 34bb0f82f8..939e4e9fe6 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx @@ -4,6 +4,7 @@ import type { IAppCardProps } from '@/app/components/app/overview/app-card' import type { BlockEnum } from '@/app/components/workflow/types' import type { UpdateAppSiteCodeResponse } from '@/models/app' import type { App } from '@/types/app' +import type { I18nKeysByPrefix } from '@/types/i18n' import * as React from 'react' import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' @@ -62,7 +63,7 @@ const CardView: FC = ({ appId, isInPanel, className }) => { const buildTriggerModeMessage = useCallback((featureName: string) => (
- {t('appOverview.overview.disableTooltip.triggerMode', { feature: featureName })} + {t('overview.disableTooltip.triggerMode', { ns: 'appOverview', feature: featureName })}
= ({ appId, isInPanel, className }) => { window.open(triggerDocUrl, '_blank') }} > - {t('appOverview.overview.appInfo.enableTooltip.learnMore')} + {t('overview.appInfo.enableTooltip.learnMore', { ns: 'appOverview' })}
), [t, triggerDocUrl]) const disableWebAppTooltip = disableAppCards - ? buildTriggerModeMessage(t('appOverview.overview.appInfo.title')) + ? buildTriggerModeMessage(t('overview.appInfo.title', { ns: 'appOverview' })) : null const disableApiTooltip = disableAppCards - ? buildTriggerModeMessage(t('appOverview.overview.apiInfo.title')) + ? buildTriggerModeMessage(t('overview.apiInfo.title', { ns: 'appOverview' })) : null const disableMcpTooltip = disableAppCards - ? buildTriggerModeMessage(t('tools.mcp.server.title')) + ? buildTriggerModeMessage(t('mcp.server.title', { ns: 'tools' })) : null const updateAppDetail = async () => { @@ -94,7 +95,7 @@ const CardView: FC = ({ appId, isInPanel, className }) => { catch (error) { console.error(error) } } - const handleCallbackResult = (err: Error | null, message?: string) => { + const handleCallbackResult = (err: Error | null, message?: I18nKeysByPrefix<'common', 'actionMsg.'>) => { const type = err ? 'error' : 'success' message ||= (type === 'success' ? 'modifiedSuccessfully' : 'modifiedUnsuccessfully') @@ -104,7 +105,7 @@ const CardView: FC = ({ appId, isInPanel, className }) => { notify({ type, - message: t(`common.actionMsg.${message}` as any) as string, + message: t(`actionMsg.${message}`, { ns: 'common' }) as string, }) } diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx index e1fca90c5d..b6e902f456 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx @@ -1,5 +1,6 @@ 'use client' import type { PeriodParams } from '@/app/components/app/overview/app-chart' +import type { I18nKeysByPrefix } from '@/types/i18n' import dayjs from 'dayjs' import quarterOfYear from 'dayjs/plugin/quarterOfYear' import * as React from 'react' @@ -16,7 +17,9 @@ dayjs.extend(quarterOfYear) const today = dayjs() -const TIME_PERIOD_MAPPING = [ +type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'> + +const TIME_PERIOD_MAPPING: { value: number, name: TimePeriodName }[] = [ { value: 0, name: 'today' }, { value: 7, name: 'last7days' }, { value: 30, name: 'last30days' }, @@ -35,8 +38,8 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) { const isChatApp = appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow' const isWorkflow = appDetail?.mode === 'workflow' const [period, setPeriod] = useState(IS_CLOUD_EDITION - ? { name: t('appLog.filter.period.today'), query: { start: today.startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } } - : { name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }, + ? { name: t('filter.period.today', { ns: 'appLog' }), query: { start: today.startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } } + : { name: t('filter.period.last7days', { ns: 'appLog' }), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }, ) if (!appDetail) @@ -45,7 +48,7 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) { return (
-
{t('common.appMenus.overview')}
+
{t('appMenus.overview', { ns: 'common' })}
{IS_CLOUD_EDITION ? ( diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/long-time-range-picker.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/long-time-range-picker.tsx index 55ed906a45..f7178d7ac2 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/long-time-range-picker.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/long-time-range-picker.tsx @@ -2,13 +2,16 @@ import type { FC } from 'react' import type { PeriodParams } from '@/app/components/app/overview/app-chart' import type { Item } from '@/app/components/base/select' +import type { I18nKeysByPrefix } from '@/types/i18n' import dayjs from 'dayjs' import * as React from 'react' import { useTranslation } from 'react-i18next' import { SimpleSelect } from '@/app/components/base/select' +type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'> + type Props = { - periodMapping: { [key: string]: { value: number, name: string } } + periodMapping: { [key: string]: { value: number, name: TimePeriodName } } onSelect: (payload: PeriodParams) => void queryDateFormat: string } @@ -25,9 +28,9 @@ const LongTimeRangePicker: FC = ({ const handleSelect = React.useCallback((item: Item) => { const id = item.value const value = periodMapping[id]?.value ?? '-1' - const name = item.name || t('appLog.filter.period.allTime') + const name = item.name || t('filter.period.allTime', { ns: 'appLog' }) if (value === -1) { - onSelect({ name: t('appLog.filter.period.allTime'), query: undefined }) + onSelect({ name: t('filter.period.allTime', { ns: 'appLog' }), query: undefined }) } else if (value === 0) { const startOfToday = today.startOf('day').format(queryDateFormat) @@ -53,7 +56,7 @@ const LongTimeRangePicker: FC = ({ return ( ({ value: k, name: t(`appLog.filter.period.${v.name}` as any) as string }))} + items={Object.entries(periodMapping).map(([k, v]) => ({ value: k, name: t(`filter.period.${v.name}`, { ns: 'appLog' }) }))} className="mt-0 !w-40" notClearable={true} onSelect={handleSelect} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx index 004f83afc5..368c3dcfc3 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/date-picker.tsx @@ -4,11 +4,11 @@ import type { FC } from 'react' import type { TriggerProps } from '@/app/components/base/date-and-time-picker/types' import { RiCalendarLine } from '@remixicon/react' import dayjs from 'dayjs' -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import * as React from 'react' import { useCallback } from 'react' import Picker from '@/app/components/base/date-and-time-picker/date-picker' -import { useI18N } from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { cn } from '@/utils/classnames' import { formatToLocalTime } from '@/utils/format' @@ -26,7 +26,7 @@ const DatePicker: FC = ({ onStartChange, onEndChange, }) => { - const { locale } = useI18N() + const locale = useLocale() const renderDate = useCallback(({ value, handleClickTrigger, isOpen }: TriggerProps) => { return ( diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx 
index 469bc97737..53794ad8db 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/index.tsx @@ -2,19 +2,22 @@ import type { Dayjs } from 'dayjs' import type { FC } from 'react' import type { PeriodParams, PeriodParamsWithTimeRange } from '@/app/components/app/overview/app-chart' +import type { I18nKeysByPrefix } from '@/types/i18n' import dayjs from 'dayjs' import * as React from 'react' import { useCallback, useState } from 'react' import { HourglassShape } from '@/app/components/base/icons/src/vender/other' -import { useI18N } from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { formatToLocalTime } from '@/utils/format' import DatePicker from './date-picker' import RangeSelector from './range-selector' const today = dayjs() +type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'> + type Props = { - ranges: { value: number, name: string }[] + ranges: { value: number, name: TimePeriodName }[] onSelect: (payload: PeriodParams) => void queryDateFormat: string } @@ -24,7 +27,7 @@ const TimeRangePicker: FC = ({ onSelect, queryDateFormat, }) => { - const { locale } = useI18N() + const locale = useLocale() const [isCustomRange, setIsCustomRange] = useState(false) const [start, setStart] = useState(today) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx index 88cb79ce0d..986170728f 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/time-range-picker/range-selector.tsx @@ -2,6 +2,7 @@ import type { FC } from 'react' import type { PeriodParamsWithTimeRange, TimeRange } from '@/app/components/app/overview/app-chart' import type { Item } from '@/app/components/base/select' +import type { I18nKeysByPrefix } from '@/types/i18n' import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react' import dayjs from 'dayjs' import * as React from 'react' @@ -12,9 +13,11 @@ import { cn } from '@/utils/classnames' const today = dayjs() +type TimePeriodName = I18nKeysByPrefix<'appLog', 'filter.period.'> + type Props = { isCustomRange: boolean - ranges: { value: number, name: string }[] + ranges: { value: number, name: TimePeriodName }[] onSelect: (payload: PeriodParamsWithTimeRange) => void } @@ -42,7 +45,7 @@ const RangeSelector: FC = ({ const renderTrigger = useCallback((item: Item | null, isOpen: boolean) => { return (
-
{isCustomRange ? t('appLog.filter.period.custom') : item?.name}
+
{isCustomRange ? t('filter.period.custom', { ns: 'appLog' }) : item?.name}
) @@ -66,7 +69,7 @@ const RangeSelector: FC = ({ }, []) return ( ({ ...v, name: t(`appLog.filter.period.${v.name}` as any) as string }))} + items={ranges.map(v => ({ ...v, name: t(`filter.period.${v.name}`, { ns: 'appLog' }) }))} className="mt-0 !w-40" notClearable={true} onSelect={handleSelectRange} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index 35ab8a61ec..4469459b52 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -15,7 +15,7 @@ import ProviderPanel from './provider-panel' import TracingIcon from './tracing-icon' import { TracingProvider } from './type' -const I18N_PREFIX = 'app.tracing' +const I18N_PREFIX = 'tracing' export type PopupProps = { appId: string @@ -327,19 +327,19 @@ const ConfigPopup: FC = ({
-
{t(`${I18N_PREFIX}.tracing`)}
+
{t(`${I18N_PREFIX}.tracing`, { ns: 'app' })}
- {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`)} + {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`, { ns: 'app' })}
{!readOnly && ( <> {providerAllNotConfigured ? ( {switchContent} @@ -351,14 +351,14 @@ const ConfigPopup: FC = ({
- {t(`${I18N_PREFIX}.tracingDescription`)} + {t(`${I18N_PREFIX}.tracingDescription`, { ns: 'app' })}
{(providerAllConfigured || providerAllNotConfigured) ? ( <> -
{t(`${I18N_PREFIX}.configProviderTitle.${providerAllConfigured ? 'configured' : 'notConfigured'}`)}
+
{t(`${I18N_PREFIX}.configProviderTitle.${providerAllConfigured ? 'configured' : 'notConfigured'}`, { ns: 'app' })}
{langfusePanel} {langSmithPanel} @@ -375,11 +375,11 @@ const ConfigPopup: FC = ({ ) : ( <> -
{t(`${I18N_PREFIX}.configProviderTitle.configured`)}
+
{t(`${I18N_PREFIX}.configProviderTitle.configured`, { ns: 'app' })}
{configuredProviderPanel()}
-
{t(`${I18N_PREFIX}.configProviderTitle.moreProvider`)}
+
{t(`${I18N_PREFIX}.configProviderTitle.moreProvider`, { ns: 'app' })}
{moreProviderPanel()}
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 438dbddfb8..5e7d98d191 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -23,7 +23,7 @@ import ConfigButton from './config-button' import TracingIcon from './tracing-icon' import { TracingProvider } from './type' -const I18N_PREFIX = 'app.tracing' +const I18N_PREFIX = 'tracing' const Panel: FC = () => { const { t } = useTranslation() @@ -45,7 +45,7 @@ const Panel: FC = () => { if (!noToast) { Toast.notify({ type: 'success', - message: t('common.api.success'), + message: t('api.success', { ns: 'common' }), }) } } @@ -254,7 +254,7 @@ const Panel: FC = () => { )} > -
{t(`${I18N_PREFIX}.title`)}
+
{t(`${I18N_PREFIX}.title`, { ns: 'app' })}
@@ -295,7 +295,7 @@ const Panel: FC = () => {
- {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`)} + {t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`, { ns: 'app' })}
{InUseProviderIcon && } diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx index 0ef97e6970..ff78712c3c 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx @@ -30,7 +30,7 @@ type Props = { onChosen: (provider: TracingProvider) => void } -const I18N_PREFIX = 'app.tracing.configProvider' +const I18N_PREFIX = 'tracing.configProvider' const arizeConfigTemplate = { api_key: '', @@ -157,7 +157,7 @@ const ProviderConfigModal: FC = ({ }) Toast.notify({ type: 'success', - message: t('common.api.remove'), + message: t('api.remove', { ns: 'common' }), }) onRemoved() hideRemoveConfirm() @@ -177,37 +177,37 @@ const ProviderConfigModal: FC = ({ if (type === TracingProvider.arize) { const postData = config as ArizeConfig if (!postData.api_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' }) if (!postData.space_id) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Space ID' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Space ID' }) if (!errorMessage && !postData.project) - errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.project`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) }) } if (type === TracingProvider.phoenix) { const postData = config as PhoenixConfig if (!postData.api_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' }) if (!errorMessage && !postData.project) - errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.project`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) }) } if (type === TracingProvider.langSmith) { const postData = config as LangSmithConfig if (!postData.api_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' }) if (!errorMessage && !postData.project) - errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.project`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) }) } if (type === TracingProvider.langfuse) { const postData = config as LangFuseConfig if (!errorMessage && !postData.secret_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.secretKey`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.secretKey`, { ns: 'app' }) }) if (!errorMessage && !postData.public_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.publicKey`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.publicKey`, { ns: 'app' }) }) if (!errorMessage && !postData.host) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Host' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Host' }) } if (type === TracingProvider.opik) { @@ -218,43 +218,43 @@ const ProviderConfigModal: FC = ({ if (type === 
TracingProvider.weave) { const postData = config as WeaveConfig if (!errorMessage && !postData.api_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'API Key' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' }) if (!errorMessage && !postData.project) - errorMessage = t('common.errorMsg.fieldRequired', { field: t(`${I18N_PREFIX}.project`) }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) }) } if (type === TracingProvider.aliyun) { const postData = config as AliyunConfig if (!errorMessage && !postData.app_name) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'App Name' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'App Name' }) if (!errorMessage && !postData.license_key) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'License Key' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'License Key' }) if (!errorMessage && !postData.endpoint) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Endpoint' }) } if (type === TracingProvider.mlflow) { const postData = config as MLflowConfig if (!errorMessage && !postData.tracking_uri) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Tracking URI' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Tracking URI' }) } if (type === TracingProvider.databricks) { const postData = config as DatabricksConfig if (!errorMessage && !postData.experiment_id) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Experiment ID' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Experiment ID' }) if (!errorMessage && !postData.host) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Host' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Host' }) } if (type === TracingProvider.tencent) { const postData = config as TencentConfig if (!errorMessage && !postData.token) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Token' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Token' }) if (!errorMessage && !postData.endpoint) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Endpoint' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Endpoint' }) if (!errorMessage && !postData.service_name) - errorMessage = t('common.errorMsg.fieldRequired', { field: 'Service Name' }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'Service Name' }) } return errorMessage @@ -281,7 +281,7 @@ const ProviderConfigModal: FC = ({ }) Toast.notify({ type: 'success', - message: t('common.api.success'), + message: t('api.success', { ns: 'common' }), }) onSaved(config) if (isAdd) @@ -303,8 +303,8 @@ const ProviderConfigModal: FC = ({
- {t(`${I18N_PREFIX}.title`)} - {t(`app.tracing.${type}.title`)} + {t(`${I18N_PREFIX}.title`, { ns: 'app' })} + {t(`tracing.${type}.title`, { ns: 'app' })}
@@ -317,7 +317,7 @@ const ProviderConfigModal: FC = ({ isRequired value={(config as ArizeConfig).api_key} onChange={handleConfigChange('api_key')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!} /> = ({ isRequired value={(config as ArizeConfig).space_id} onChange={handleConfigChange('space_id')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'Space ID' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'Space ID' })!} /> = ({ isRequired value={(config as PhoenixConfig).api_key} onChange={handleConfigChange('api_key')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!} /> = ({ isRequired value={(config as AliyunConfig).license_key} onChange={handleConfigChange('license_key')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'License Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'License Key' })!} /> = ({ isRequired value={(config as TencentConfig).token} onChange={handleConfigChange('token')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'Token' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'Token' })!} /> = ({ isRequired value={(config as WeaveConfig).api_key} onChange={handleConfigChange('api_key')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!} /> = ({ isRequired value={(config as LangSmithConfig).api_key} onChange={handleConfigChange('api_key')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!} /> = ({ {type === TracingProvider.langfuse && ( <> = ({ labelClassName="!text-sm" value={(config as OpikConfig).api_key} onChange={handleConfigChange('api_key')} - placeholder={t(`${I18N_PREFIX}.placeholder`, { key: 'API Key' })!} + placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!} /> = ({ {type === TracingProvider.mlflow && ( <> = ({ placeholder="http://localhost:5000" /> )} {type === TracingProvider.databricks && ( <> )} @@ -634,7 +634,7 @@ const ProviderConfigModal: FC = ({ target="_blank" href={docURL[type]} > - {t(`${I18N_PREFIX}.viewDocsLink`, { key: t(`app.tracing.${type}.title`) })} + {t(`${I18N_PREFIX}.viewDocsLink`, { ns: 'app', key: t(`tracing.${type}.title`, { ns: 'app' }) })}
@@ -644,7 +644,7 @@ const ProviderConfigModal: FC = ({ className="h-9 text-sm font-medium text-text-secondary" onClick={showRemoveConfirm} > - {t('common.operation.remove')} + {t('operation.remove', { ns: 'common' })} @@ -653,7 +653,7 @@ const ProviderConfigModal: FC = ({ className="mr-2 h-9 text-sm font-medium text-text-secondary" onClick={onCancel} > - {t('common.operation.cancel')} + {t('operation.cancel', { ns: 'common' })}
@@ -670,7 +670,7 @@ const ProviderConfigModal: FC = ({
- {t('common.modelProvider.encrypted.front')} + {t('modelProvider.encrypted.front', { ns: 'common' })} = ({ > PKCS1_OAEP - {t('common.modelProvider.encrypted.back')} + {t('modelProvider.encrypted.back', { ns: 'common' })}
@@ -691,8 +691,8 @@ const ProviderConfigModal: FC = ({ diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx index 6c66b19ad3..dc0fe2fbbc 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx @@ -11,7 +11,7 @@ import { Eye as View } from '@/app/components/base/icons/src/vender/solid/genera import { cn } from '@/utils/classnames' import { TracingProvider } from './type' -const I18N_PREFIX = 'app.tracing' +const I18N_PREFIX = 'tracing' type Props = { type: TracingProvider @@ -82,14 +82,14 @@ const ProviderPanel: FC = ({
- {isChosen &&
{t(`${I18N_PREFIX}.inUse`)}
} + {isChosen &&
{t(`${I18N_PREFIX}.inUse`, { ns: 'app' })}
}
{!readOnly && (
{hasConfigured && (
-
{t(`${I18N_PREFIX}.view`)}
+
{t(`${I18N_PREFIX}.view`, { ns: 'app' })}
)}
= ({ onClick={handleConfigBtnClick} > -
{t(`${I18N_PREFIX}.config`)}
+
{t(`${I18N_PREFIX}.config`, { ns: 'app' })}
)}
- {t(`${I18N_PREFIX}.${type}.description`)} + {t(`${I18N_PREFIX}.${type}.description`, { ns: 'app' })}
) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/layout.tsx index 4135482dd9..a918ae2786 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/layout.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/layout.tsx @@ -15,7 +15,7 @@ const AppDetail: FC = ({ children }) => { const router = useRouter() const { isCurrentWorkspaceDatasetOperator } = useAppContext() const { t } = useTranslation() - useDocumentTitle(t('common.menus.appDetail')) + useDocumentTitle(t('menus.appDetail', { ns: 'common' })) useEffect(() => { if (isCurrentWorkspaceDatasetOperator) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 10a12d75e1..1c5434924f 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -70,14 +70,14 @@ const DatasetDetailLayout: FC = (props) => { const navigation = useMemo(() => { const baseNavigation = [ { - name: t('common.datasetMenus.hitTesting'), + name: t('datasetMenus.hitTesting', { ns: 'common' }), href: `/datasets/${datasetId}/hitTesting`, icon: RiFocus2Line, selectedIcon: RiFocus2Fill, disabled: isButtonDisabledWithPipeline, }, { - name: t('common.datasetMenus.settings'), + name: t('datasetMenus.settings', { ns: 'common' }), href: `/datasets/${datasetId}/settings`, icon: RiEqualizer2Line, selectedIcon: RiEqualizer2Fill, @@ -87,14 +87,14 @@ const DatasetDetailLayout: FC = (props) => { if (datasetRes?.provider !== 'external') { baseNavigation.unshift({ - name: t('common.datasetMenus.pipeline'), + name: t('datasetMenus.pipeline', { ns: 'common' }), href: `/datasets/${datasetId}/pipeline`, icon: PipelineLine as RemixiconComponentType, selectedIcon: PipelineFill as RemixiconComponentType, disabled: false, }) baseNavigation.unshift({ - name: t('common.datasetMenus.documents'), + name: t('datasetMenus.documents', { ns: 'common' }), href: `/datasets/${datasetId}/documents`, icon: RiFileTextLine, selectedIcon: RiFileTextFill, @@ -105,7 +105,7 @@ const DatasetDetailLayout: FC = (props) => { return baseNavigation }, [t, datasetId, isButtonDisabledWithPipeline, datasetRes?.provider]) - useDocumentTitle(datasetRes?.name || t('common.menus.datasets')) + useDocumentTitle(datasetRes?.name || t('menus.datasets', { ns: 'common' })) const setAppSidebarExpand = useStore(state => state.setAppSidebarExpand) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx index aa64df3449..8080b565cd 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -1,3 +1,4 @@ +/* eslint-disable dify-i18n/require-ns-option */ import * as React from 'react' import Form from '@/app/components/datasets/settings/form' import { getLocaleOnServer, getTranslation } from '@/i18n-config/server' @@ -9,7 +10,7 @@ const Settings = async () => { return (
-
{t('title') as any}
+
{t('title')}
{t('desc')}
diff --git a/web/app/(commonLayout)/explore/layout.tsx b/web/app/(commonLayout)/explore/layout.tsx index 5928308cdc..7d59d397f9 100644 --- a/web/app/(commonLayout)/explore/layout.tsx +++ b/web/app/(commonLayout)/explore/layout.tsx @@ -7,7 +7,7 @@ import useDocumentTitle from '@/hooks/use-document-title' const ExploreLayout: FC = ({ children }) => { const { t } = useTranslation() - useDocumentTitle(t('common.menus.explore')) + useDocumentTitle(t('menus.explore', { ns: 'common' })) return ( {children} diff --git a/web/app/(commonLayout)/tools/page.tsx b/web/app/(commonLayout)/tools/page.tsx index 2d5c1a8e44..3e88050eba 100644 --- a/web/app/(commonLayout)/tools/page.tsx +++ b/web/app/(commonLayout)/tools/page.tsx @@ -12,7 +12,7 @@ const ToolsList: FC = () => { const router = useRouter() const { isCurrentWorkspaceDatasetOperator } = useAppContext() const { t } = useTranslation() - useDocumentTitle(t('common.menus.tools')) + useDocumentTitle(t('menus.tools', { ns: 'common' })) useEffect(() => { if (isCurrentWorkspaceDatasetOperator) diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx index 00288b7a61..113f3b5680 100644 --- a/web/app/(shareLayout)/components/authenticated-layout.tsx +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -81,7 +81,7 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => { return (
- {t('common.userProfile.logout')} + {t('userProfile.logout', { ns: 'common' })}
) } diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index b8ea1b2e56..9f89a03993 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -94,8 +94,8 @@ const Splash: FC = ({ children }) => { if (message) { return (
- - {code === '403' ? t('common.userProfile.logout') : t('share.login.backToHome')} + + {code === '403' ? t('userProfile.logout', { ns: 'common' }) : t('login.backToHome', { ns: 'share' })}
) } diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx index 9ce058340c..fbf45259e5 100644 --- a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx @@ -3,12 +3,12 @@ import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import Countdown from '@/app/components/signin/countdown' -import I18NContext from '@/context/i18n' + +import { useLocale } from '@/context/i18n' import { sendWebAppResetPasswordCode, verifyWebAppResetPasswordCode } from '@/service/common' export default function CheckCode() { @@ -19,21 +19,21 @@ export default function CheckCode() { const token = decodeURIComponent(searchParams.get('token') as string) const [code, setVerifyCode] = useState('') const [loading, setIsLoading] = useState(false) - const { locale } = useContext(I18NContext) + const locale = useLocale() const verify = async () => { try { if (!code.trim()) { Toast.notify({ type: 'error', - message: t('login.checkCode.emptyCode'), + message: t('checkCode.emptyCode', { ns: 'login' }), }) return } if (!/\d{6}/.test(code)) { Toast.notify({ type: 'error', - message: t('login.checkCode.invalidCode'), + message: t('checkCode.invalidCode', { ns: 'login' }), }) return } @@ -69,22 +69,22 @@ export default function CheckCode() {
-

{t('login.checkCode.checkYourEmail')}

+

{t('checkCode.checkYourEmail', { ns: 'login' })}

- {t('login.checkCode.tipsPrefix')} + {t('checkCode.tipsPrefix', { ns: 'login' })} {email}
- {t('login.checkCode.validTime')} + {t('checkCode.validTime', { ns: 'login' })}

- - setVerifyCode(e.target.value)} maxLength={6} className="mt-1" placeholder={t('login.checkCode.verificationCodePlaceholder') || ''} /> - + + setVerifyCode(e.target.value)} maxLength={6} className="mt-1" placeholder={t('checkCode.verificationCodePlaceholder', { ns: 'login' }) || ''} /> +
@@ -94,7 +94,7 @@ export default function CheckCode() {
- {t('login.back')} + {t('back', { ns: 'login' })}
) diff --git a/web/app/(shareLayout)/webapp-reset-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/page.tsx index 108bd4b22e..9b9a853cdd 100644 --- a/web/app/(shareLayout)/webapp-reset-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/page.tsx @@ -1,17 +1,17 @@ 'use client' import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react' -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import Link from 'next/link' import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' -import I18NContext from '@/context/i18n' + +import { useLocale } from '@/context/i18n' import useDocumentTitle from '@/hooks/use-document-title' import { sendResetPasswordCode } from '@/service/common' @@ -22,19 +22,19 @@ export default function CheckCode() { const router = useRouter() const [email, setEmail] = useState('') const [loading, setIsLoading] = useState(false) - const { locale } = useContext(I18NContext) + const locale = useLocale() const handleGetEMailVerificationCode = async () => { try { if (!email) { - Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) }) return } if (!emailRegex.test(email)) { Toast.notify({ type: 'error', - message: t('login.error.emailInValid'), + message: t('error.emailInValid', { ns: 'login' }), }) return } @@ -50,7 +50,7 @@ export default function CheckCode() { else if (res.code === 'account_not_found') { Toast.notify({ type: 'error', - message: t('login.error.registrationNotAllowed'), + message: t('error.registrationNotAllowed', { ns: 'login' }), }) } else { @@ -74,21 +74,21 @@ export default function CheckCode() {
-

{t('login.resetPassword')}

+

{t('resetPassword', { ns: 'login' })}

- {t('login.resetPasswordDesc')} + {t('resetPasswordDesc', { ns: 'login' })}

- +
- setEmail(e.target.value)} /> + setEmail(e.target.value)} />
- +
@@ -99,7 +99,7 @@ export default function CheckCode() {
- {t('login.backToLogin')} + {t('backToLogin', { ns: 'login' })}
) diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx index bfb71d9c6f..9f59e8f9eb 100644 --- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -45,15 +45,15 @@ const ChangePasswordForm = () => { const valid = useCallback(() => { if (!password.trim()) { - showErrorMessage(t('login.error.passwordEmpty')) + showErrorMessage(t('error.passwordEmpty', { ns: 'login' })) return false } if (!validPassword.test(password)) { - showErrorMessage(t('login.error.passwordInvalid')) + showErrorMessage(t('error.passwordInvalid', { ns: 'login' })) return false } if (password !== confirmPassword) { - showErrorMessage(t('common.account.notEqual')) + showErrorMessage(t('account.notEqual', { ns: 'common' })) return false } return true @@ -92,10 +92,10 @@ const ChangePasswordForm = () => {

- {t('login.changePassword')}
+ {t('changePassword', { ns: 'login' })}

- {t('login.changePasswordTip')}
+ {t('changePasswordTip', { ns: 'login' })}

@@ -104,7 +104,7 @@ const ChangePasswordForm = () => { {/* Password */}
{ type={showPassword ? 'text' : 'password'} value={password} onChange={e => setPassword(e.target.value)} - placeholder={t('login.passwordPlaceholder') || ''} + placeholder={t('passwordPlaceholder', { ns: 'login' }) || ''} />
@@ -125,12 +125,12 @@ const ChangePasswordForm = () => {
-
{t('login.error.passwordInvalid')}
+
{t('error.passwordInvalid', { ns: 'login' })}
{/* Confirm Password */}
{ type={showConfirmPassword ? 'text' : 'password'} value={confirmPassword} onChange={e => setConfirmPassword(e.target.value)} - placeholder={t('login.confirmPasswordPlaceholder') || ''} + placeholder={t('confirmPasswordPlaceholder', { ns: 'login' }) || ''} />
@@ -171,7 +171,7 @@ const ChangePasswordForm = () => {

- {t('login.passwordChangedTip')}
+ {t('passwordChangedTip', { ns: 'login' })}

@@ -183,7 +183,7 @@ const ChangePasswordForm = () => { router.replace(getSignInUrl()) }} > - {t('login.passwordChanged')} + {t('passwordChanged', { ns: 'login' })} {' '} ( {Math.round(countdown / 1000)} diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx index ee7fc22bea..bda5484197 100644 --- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -4,12 +4,12 @@ import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import Countdown from '@/app/components/signin/countdown' -import I18NContext from '@/context/i18n' + +import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common' import { fetchAccessToken } from '@/service/share' @@ -23,7 +23,7 @@ export default function CheckCode() { const token = decodeURIComponent(searchParams.get('token') as string) const [code, setVerifyCode] = useState('') const [loading, setIsLoading] = useState(false) - const { locale } = useContext(I18NContext) + const locale = useLocale() const codeInputRef = useRef(null) const redirectUrl = searchParams.get('redirect_url') const embeddedUserId = useWebAppStore(s => s.embeddedUserId) @@ -45,21 +45,21 @@ export default function CheckCode() { if (!code.trim()) { Toast.notify({ type: 'error', - message: t('login.checkCode.emptyCode'), + message: t('checkCode.emptyCode', { ns: 'login' }), }) return } if (!/\d{6}/.test(code)) { Toast.notify({ type: 'error', - message: t('login.checkCode.invalidCode'), + message: t('checkCode.invalidCode', { ns: 'login' }), }) return } if (!redirectUrl || !appCode) { Toast.notify({ type: 'error', - message: t('login.error.redirectUrlMissing'), + message: t('error.redirectUrlMissing', { ns: 'login' }), }) return } @@ -108,19 +108,19 @@ export default function CheckCode() {
-

{t('login.checkCode.checkYourEmail')}

+

{t('checkCode.checkYourEmail', { ns: 'login' })}

- {t('login.checkCode.tipsPrefix')}
+ {t('checkCode.tipsPrefix', { ns: 'login' })}
  {email}
- {t('login.checkCode.validTime')}
+ {t('checkCode.validTime', { ns: 'login' })}

- + setVerifyCode(e.target.value)} maxLength={6} className="mt-1" - placeholder={t('login.checkCode.verificationCodePlaceholder') || ''} + placeholder={t('checkCode.verificationCodePlaceholder', { ns: 'login' }) || ''} /> - +
@@ -140,7 +140,7 @@ export default function CheckCode() {
- {t('login.back')}
+ {t('back', { ns: 'login' })}
) diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx index 8b611b9eea..5aa9d9f141 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx @@ -1,14 +1,13 @@ -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' -import I18NContext from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { sendWebAppEMailLoginCode } from '@/service/common' export default function MailAndCodeAuth() { @@ -18,19 +17,19 @@ export default function MailAndCodeAuth() { const emailFromLink = decodeURIComponent(searchParams.get('email') || '') const [email, setEmail] = useState(emailFromLink) const [loading, setIsLoading] = useState(false) - const { locale } = useContext(I18NContext) + const locale = useLocale() const handleGetEMailVerificationCode = async () => { try { if (!email) { - Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) }) return } if (!emailRegex.test(email)) { Toast.notify({ type: 'error', - message: t('login.error.emailInValid'), + message: t('error.emailInValid', { ns: 'login' }), }) return } @@ -56,12 +55,12 @@ export default function MailAndCodeAuth() {
- +
- setEmail(e.target.value)} /> + setEmail(e.target.value)} />
- +
diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx index 46645ed68c..23ac83e76c 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -1,15 +1,14 @@ 'use client' -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import Link from 'next/link' import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { emailRegex } from '@/config' -import I18NContext from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' import { webAppLogin } from '@/service/common' import { fetchAccessToken } from '@/service/share' @@ -21,7 +20,7 @@ type MailAndPasswordAuthProps = { export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAuthProps) { const { t } = useTranslation() - const { locale } = useContext(I18NContext) + const locale = useLocale() const router = useRouter() const searchParams = useSearchParams() const [showPassword, setShowPassword] = useState(false) @@ -46,25 +45,25 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut const appCode = getAppCodeFromRedirectUrl() const handleEmailPasswordLogin = async () => { if (!email) { - Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) }) return } if (!emailRegex.test(email)) { Toast.notify({ type: 'error', - message: t('login.error.emailInValid'), + message: t('error.emailInValid', { ns: 'login' }), }) return } if (!password?.trim()) { - Toast.notify({ type: 'error', message: t('login.error.passwordEmpty') }) + Toast.notify({ type: 'error', message: t('error.passwordEmpty', { ns: 'login' }) }) return } if (!redirectUrl || !appCode) { Toast.notify({ type: 'error', - message: t('login.error.redirectUrlMissing'), + message: t('error.redirectUrlMissing', { ns: 'login' }), }) return } @@ -111,7 +110,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
@@ -128,14 +127,14 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
@@ -149,7 +148,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
 }}
 type={showPassword ? 'text' : 'password'}
 autoComplete="current-password"
- placeholder={t('login.passwordPlaceholder') || ''}
+ placeholder={t('passwordPlaceholder', { ns: 'login' }) || ''}
 tabIndex={2}
 />
@@ -172,7 +171,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
 disabled={isLoading || !email || !password}
 className="w-full"
 >
- {t('login.signBtn')}
+ {t('signBtn', { ns: 'login' })}
diff --git a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx
index 472952c2c8..d8f3854868 100644
--- a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx
+++ b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx
@@ -82,7 +82,7 @@ const SSOAuth: FC = ({
 className="w-full"
 >
- {t('login.withSSO')}
+ {t('withSSO', { ns: 'login' })}
 )
 }
diff --git a/web/app/(shareLayout)/webapp-signin/layout.tsx b/web/app/(shareLayout)/webapp-signin/layout.tsx
index dd4510a541..21cb0e1f57 100644
--- a/web/app/(shareLayout)/webapp-signin/layout.tsx
+++ b/web/app/(shareLayout)/webapp-signin/layout.tsx
@@ -9,7 +9,7 @@ export default function SignInLayout({ children }: PropsWithChildren) {
 const { t } = useTranslation()
 const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
- useDocumentTitle(t('login.webapp.login'))
+ useDocumentTitle(t('webapp.login', { ns: 'login' }))
 return (
 <>
diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx index 40d34dcaf5..b15145346f 100644 --- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx +++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx @@ -60,8 +60,8 @@ const NormalForm = () => {
-

{t('login.licenseLost')}

-

{t('login.licenseLostTip')}

+

{t('licenseLost', { ns: 'login' })}

+

{t('licenseLostTip', { ns: 'login' })}

@@ -76,8 +76,8 @@ const NormalForm = () => {
-

{t('login.licenseExpired')}

-

{t('login.licenseExpiredTip')}

+

{t('licenseExpired', { ns: 'login' })}

+

{t('licenseExpiredTip', { ns: 'login' })}

@@ -92,8 +92,8 @@ const NormalForm = () => { -

{t('login.licenseInactive')}

-

{t('login.licenseInactiveTip')}

+

{t('licenseInactive', { ns: 'login' })}

+

{t('licenseInactiveTip', { ns: 'login' })}

@@ -104,8 +104,8 @@ const NormalForm = () => { <>
-

{systemFeatures.branding.enabled ? t('login.pageTitleForE') : t('login.pageTitle')}

-

{t('login.welcome')}

+

{systemFeatures.branding.enabled ? t('pageTitleForE', { ns: 'login' }) : t('pageTitle', { ns: 'login' })}

+

{t('welcome', { ns: 'login' })}

@@ -122,7 +122,7 @@ const NormalForm = () => {
- {t('login.or')}
+ {t('or', { ns: 'login' })}
)} @@ -134,7 +134,7 @@ const NormalForm = () => { {systemFeatures.enable_email_password_login && (
 { updateAuthType('password') }}>
- {t('login.usePassword')}
+ {t('usePassword', { ns: 'login' })}
)} @@ -144,7 +144,7 @@ const NormalForm = () => { {systemFeatures.enable_email_code_login && (
 { updateAuthType('code') }}>
- {t('login.useVerificationCode')}
+ {t('useVerificationCode', { ns: 'login' })}
)} @@ -158,8 +158,8 @@ const NormalForm = () => {
-

{t('login.noLoginMethod')}

-

{t('login.noLoginMethodTip')}

+

{t('noLoginMethod', { ns: 'login' })}

+

{t('noLoginMethodTip', { ns: 'login' })}

{step === STEP.start && ( <> -
{t('common.account.changeEmail.title')}
+
{t('account.changeEmail.title', { ns: 'common' })}
-
{t('common.account.changeEmail.authTip')}
+
{t('account.changeEmail.authTip', { ns: 'common' })}
}} values={{ email }} /> @@ -227,34 +228,35 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { variant="primary" onClick={sendCodeToOriginEmail} > - {t('common.account.changeEmail.sendVerifyCode')} + {t('account.changeEmail.sendVerifyCode', { ns: 'common' })}
)} {step === STEP.verifyOrigin && ( <> -
{t('common.account.changeEmail.verifyEmail')}
+
{t('account.changeEmail.verifyEmail', { ns: 'common' })}
}} values={{ email }} />
-
{t('common.account.changeEmail.codeLabel')}
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
setCode(e.target.value)} maxLength={6} @@ -267,46 +269,46 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { variant="primary" onClick={handleVerifyOriginEmail} > - {t('common.account.changeEmail.continue')} + {t('account.changeEmail.continue', { ns: 'common' })}
- {t('common.account.changeEmail.resendTip')}
+ {t('account.changeEmail.resendTip', { ns: 'common' })}
  {time > 0 && (
- {t('common.account.changeEmail.resendCount', { count: time })}
+ {t('account.changeEmail.resendCount', { ns: 'common', count: time })}
  )}
  {!time && (
- {t('common.account.changeEmail.resend')}
+ {t('account.changeEmail.resend', { ns: 'common' })}
  )}
)} {step === STEP.newEmail && ( <> -
{t('common.account.changeEmail.newEmail')}
+
{t('account.changeEmail.newEmail', { ns: 'common' })}
-
{t('common.account.changeEmail.content3')}
+
{t('account.changeEmail.content3', { ns: 'common' })}
-
{t('common.account.changeEmail.emailLabel')}
+
{t('account.changeEmail.emailLabel', { ns: 'common' })}
handleNewEmailValueChange(e.target.value)} destructive={newEmailExited || unAvailableEmail} /> {newEmailExited && ( -
{t('common.account.changeEmail.existingEmail')}
+
{t('account.changeEmail.existingEmail', { ns: 'common' })}
)} {unAvailableEmail && ( -
{t('common.account.changeEmail.unAvailableEmail')}
+
{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}
)}
@@ -316,34 +318,35 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { variant="primary" onClick={sendCodeToNewEmail} > - {t('common.account.changeEmail.sendVerifyCode')} + {t('account.changeEmail.sendVerifyCode', { ns: 'common' })}
)} {step === STEP.verifyNew && ( <> -
{t('common.account.changeEmail.verifyNew')}
+
{t('account.changeEmail.verifyNew', { ns: 'common' })}
}} values={{ email: mail }} />
-
{t('common.account.changeEmail.codeLabel')}
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
setCode(e.target.value)} maxLength={6} @@ -356,22 +359,22 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { variant="primary" onClick={submitNewEmail} > - {t('common.account.changeEmail.changeTo', { email: mail })} + {t('account.changeEmail.changeTo', { ns: 'common', email: mail })}
- {t('common.account.changeEmail.resendTip')}
+ {t('account.changeEmail.resendTip', { ns: 'common' })}
  {time > 0 && (
- {t('common.account.changeEmail.resendCount', { count: time })}
+ {t('account.changeEmail.resendCount', { ns: 'common', count: time })}
  )}
  {!time && (
- {t('common.account.changeEmail.resend')}
+ {t('account.changeEmail.resend', { ns: 'common' })}
  )}
diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index baa9759ffe..f01efc002c 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -61,7 +61,7 @@ export default function AccountPage() { try { setEditing(true) await updateUserProfile({ url: 'account/name', body: { name: editName } }) - notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) mutateUserProfile() setEditNameModalVisible(false) setEditing(false) @@ -80,15 +80,15 @@ export default function AccountPage() { } const valid = () => { if (!password.trim()) { - showErrorMessage(t('login.error.passwordEmpty')) + showErrorMessage(t('error.passwordEmpty', { ns: 'login' })) return false } if (!validPassword.test(password)) { - showErrorMessage(t('login.error.passwordInvalid')) + showErrorMessage(t('error.passwordInvalid', { ns: 'login' })) return false } if (password !== confirmPassword) { - showErrorMessage(t('common.account.notEqual')) + showErrorMessage(t('account.notEqual', { ns: 'common' })) return false } @@ -112,7 +112,7 @@ export default function AccountPage() { repeat_new_password: confirmPassword, }, }) - notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) mutateUserProfile() setEditPasswordModalVisible(false) resetPasswordForm() @@ -146,7 +146,7 @@ export default function AccountPage() { return ( <>
-

{t('common.account.myAccount')}

+

{t('account.myAccount', { ns: 'common' })}

@@ -164,25 +164,25 @@ export default function AccountPage() {
-
{t('common.account.name')}
+
{t('account.name', { ns: 'common' })}
{userProfile.name}
- {t('common.operation.edit')}
+ {t('operation.edit', { ns: 'common' })}
-
{t('common.account.email')}
+
{t('account.email', { ns: 'common' })}
{userProfile.email}
{systemFeatures.enable_change_email && (
 setShowUpdateEmail(true)}>
- {t('common.operation.change')}
+ {t('operation.change', { ns: 'common' })}
)}
@@ -191,26 +191,26 @@ export default function AccountPage() { systemFeatures.enable_email_password_login && (
-
{t('common.account.password')}
-
{t('common.account.passwordTip')}
+
{t('account.password', { ns: 'common' })}
+
{t('account.passwordTip', { ns: 'common' })}
- +
) }
-
{t('common.account.langGeniusAccount')}
-
{t('common.account.langGeniusAccountTip')}
+
{t('account.langGeniusAccount', { ns: 'common' })}
+
{t('account.langGeniusAccountTip', { ns: 'common' })}
{!!apps.length && ( ({ ...app, key: app.id, name: app.name }))} renderItem={renderAppItem} wrapperClassName="mt-2" /> )} - {!IS_CE_EDITION && } + {!IS_CE_EDITION && }
{ editNameModalVisible && ( @@ -219,21 +219,21 @@ export default function AccountPage() { onClose={() => setEditNameModalVisible(false)} className="!w-[420px] !p-6" > -
{t('common.account.editName')}
-
{t('common.account.name')}
+
{t('account.editName', { ns: 'common' })}
+
{t('account.name', { ns: 'common' })}
setEditName(e.target.value)} />
- +
@@ -249,10 +249,10 @@ export default function AccountPage() { }} className="!w-[420px] !p-6" > -
{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}
+
{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}
{userProfile.is_password_set && ( <> -
{t('common.account.currentPassword')}
+
{t('account.currentPassword', { ns: 'common' })}
)}
- {userProfile.is_password_set ? t('common.account.newPassword') : t('common.account.password')}
+ {userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })}
-
{t('common.account.confirmPassword')}
+
{t('account.confirmPassword', { ns: 'common' })}
- {t('common.operation.cancel')}
+ {t('operation.cancel', { ns: 'common' })}
diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx index 3cbbb47a76..8ea29e8e45 100644 --- a/web/app/account/(commonLayout)/avatar.tsx +++ b/web/app/account/(commonLayout)/avatar.tsx @@ -94,7 +94,7 @@ export default function AppSelector() { className="group flex h-9 cursor-pointer items-center justify-start rounded-lg px-3 hover:bg-state-base-hover" > -
{t('common.userProfile.logout')}
+
{t('userProfile.logout', { ns: 'common' })}
diff --git a/web/app/account/(commonLayout)/delete-account/components/check-email.tsx b/web/app/account/(commonLayout)/delete-account/components/check-email.tsx index c0bbf4422a..65a58c936e 100644 --- a/web/app/account/(commonLayout)/delete-account/components/check-email.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/check-email.tsx @@ -31,22 +31,22 @@ export default function CheckEmail(props: DeleteAccountProps) { return ( <>
- {t('common.account.deleteTip')}
+ {t('account.deleteTip', { ns: 'common' })}
- {t('common.account.deletePrivacyLinkTip')}
- {t('common.account.deletePrivacyLink')}
+ {t('account.deletePrivacyLinkTip', { ns: 'common' })}
+ {t('account.deletePrivacyLink', { ns: 'common' })}
- + { setUserInputEmail(e.target.value) }} />
- - + +
) diff --git a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx index 4a1b41cb20..67fea3c141 100644 --- a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx @@ -28,7 +28,7 @@ export default function FeedBack(props: DeleteAccountProps) { await logout() // Tokens are now stored in cookies and cleared by backend router.push('/signin') - Toast.notify({ type: 'info', message: t('common.account.deleteSuccessTip') }) + Toast.notify({ type: 'info', message: t('account.deleteSuccessTip', { ns: 'common' }) }) } catch (error) { console.error(error) } }, [router, t]) @@ -50,22 +50,22 @@ export default function FeedBack(props: DeleteAccountProps) { - + @@ -116,10 +116,10 @@ const MCPServerModal = ({ {latestParams.length > 0 && (
-
{t('tools.mcp.server.modal.parameters')}
+
{t('mcp.server.modal.parameters', { ns: 'tools' })}
-
{t('tools.mcp.server.modal.parametersTip')}
+
{t('mcp.server.modal.parametersTip', { ns: 'tools' })}
{latestParams.map(paramItem => (
- - + +
) diff --git a/web/app/components/tools/mcp/mcp-server-param-item.tsx b/web/app/components/tools/mcp/mcp-server-param-item.tsx index d951f19caa..db27cfdf98 100644 --- a/web/app/components/tools/mcp/mcp-server-param-item.tsx +++ b/web/app/components/tools/mcp/mcp-server-param-item.tsx @@ -27,7 +27,7 @@ const MCPServerParamItem = ({ diff --git a/web/app/components/tools/mcp/mcp-service-card.tsx b/web/app/components/tools/mcp/mcp-service-card.tsx index 07aa5e0168..fc6de3bc3d 100644 --- a/web/app/components/tools/mcp/mcp-service-card.tsx +++ b/web/app/components/tools/mcp/mcp-service-card.tsx @@ -172,7 +172,7 @@ function MCPServiceCard({
- {t('tools.mcp.server.title')}
+ {t('mcp.server.title', { ns: 'tools' })}
@@ -180,8 +180,8 @@ function MCPServiceCard({
 {serverActivated
-   ? t('appOverview.overview.status.running')
-   : t('appOverview.overview.status.disable')}
+   ? t('overview.status.running', { ns: 'appOverview' })
+   : t('overview.status.disable', { ns: 'appOverview' })}
- {t('appOverview.overview.appInfo.enableTooltip.description')}
+ {t('overview.appInfo.enableTooltip.description', { ns: 'appOverview' })}
 window.open(docLink('/guides/workflow/node/user-input'), '_blank')}
 >
- {t('appOverview.overview.appInfo.enableTooltip.learnMore')}
+ {t('overview.appInfo.enableTooltip.learnMore', { ns: 'appOverview' })}
) @@ -222,7 +222,7 @@ function MCPServiceCard({ {!isMinimalState && (
- {t('tools.mcp.server.url')}
+ {t('mcp.server.url', { ns: 'tools' })}
@@ -239,7 +239,7 @@ function MCPServiceCard({ {isCurrentWorkspaceManager && (
-
{serverPublished ? t('tools.mcp.server.edit') : t('tools.mcp.server.addDescription')}
+
{serverPublished ? t('mcp.server.edit', { ns: 'tools' }) : t('mcp.server.addDescription', { ns: 'tools' })}
@@ -287,8 +287,8 @@ function MCPServiceCard({ {showConfirmDelete && ( { onGenCode() diff --git a/web/app/components/tools/mcp/modal.tsx b/web/app/components/tools/mcp/modal.tsx index c5cde65674..413a2d3948 100644 --- a/web/app/components/tools/mcp/modal.tsx +++ b/web/app/components/tools/mcp/modal.tsx @@ -5,7 +5,7 @@ import type { ToolWithProvider } from '@/app/components/workflow/types' import type { AppIconType } from '@/types/app' import { RiCloseLine, RiEditLine } from '@remixicon/react' import { useHover } from 'ahooks' -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import * as React from 'react' import { useCallback, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -81,15 +81,15 @@ const MCPModal = ({ const authMethods = [ { - text: t('tools.mcp.modal.authentication'), + text: t('mcp.modal.authentication', { ns: 'tools' }), value: MCPAuthMethod.authentication, }, { - text: t('tools.mcp.modal.headers'), + text: t('mcp.modal.headers', { ns: 'tools' }), value: MCPAuthMethod.headers, }, { - text: t('tools.mcp.modal.configurations'), + text: t('mcp.modal.configurations', { ns: 'tools' }), value: MCPAuthMethod.configurations, }, ] @@ -231,33 +231,33 @@ const MCPModal = ({
-
{!isCreate ? t('tools.mcp.modal.editTitle') : t('tools.mcp.modal.title')}
+
{!isCreate ? t('mcp.modal.editTitle', { ns: 'tools' }) : t('mcp.modal.title', { ns: 'tools' })}
- {t('tools.mcp.modal.serverUrl')}
+ {t('mcp.modal.serverUrl', { ns: 'tools' })}
setUrl(e.target.value)} onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.serverUrlPlaceholder')} + placeholder={t('mcp.modal.serverUrlPlaceholder', { ns: 'tools' })} /> {originalServerUrl && originalServerUrl !== url && (
- {t('tools.mcp.modal.serverUrlWarning')} + {t('mcp.modal.serverUrlWarning', { ns: 'tools' })}
)}
- {t('tools.mcp.modal.name')} + {t('mcp.modal.name', { ns: 'tools' })}
setName(e.target.value)} - placeholder={t('tools.mcp.modal.namePlaceholder')} + placeholder={t('mcp.modal.namePlaceholder', { ns: 'tools' })} />
@@ -284,17 +284,17 @@ const MCPModal = ({
- {t('tools.mcp.modal.serverIdentifier')} + {t('mcp.modal.serverIdentifier', { ns: 'tools' })}
-
{t('tools.mcp.modal.serverIdentifierTip')}
+
{t('mcp.modal.serverIdentifierTip', { ns: 'tools' })}
setServerIdentifier(e.target.value)} - placeholder={t('tools.mcp.modal.serverIdentifierPlaceholder')} + placeholder={t('mcp.modal.serverIdentifierPlaceholder', { ns: 'tools' })} /> {originalServerID && originalServerID !== serverIdentifier && (
- {t('tools.mcp.modal.serverIdentifierWarning')} + {t('mcp.modal.serverIdentifierWarning', { ns: 'tools' })}
)}
@@ -317,13 +317,13 @@ const MCPModal = ({ defaultValue={isDynamicRegistration} onChange={setIsDynamicRegistration} /> - {t('tools.mcp.modal.useDynamicClientRegistration')} + {t('mcp.modal.useDynamicClientRegistration', { ns: 'tools' })}
{!isDynamicRegistration && (
-
{t('tools.mcp.modal.redirectUrlWarning')}
+
{t('mcp.modal.redirectUrlWarning', { ns: 'tools' })}
{`${API_PREFIX}/mcp/oauth/callback`} @@ -333,25 +333,25 @@ const MCPModal = ({
- {t('tools.mcp.modal.clientID')} + {t('mcp.modal.clientID', { ns: 'tools' })}
setClientID(e.target.value)} onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.clientID')} + placeholder={t('mcp.modal.clientID', { ns: 'tools' })} disabled={isDynamicRegistration} />
- {t('tools.mcp.modal.clientSecret')} + {t('mcp.modal.clientSecret', { ns: 'tools' })}
setCredentials(e.target.value)} onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.clientSecretPlaceholder')} + placeholder={t('mcp.modal.clientSecretPlaceholder', { ns: 'tools' })} disabled={isDynamicRegistration} />
@@ -362,9 +362,9 @@ const MCPModal = ({ authMethod === MCPAuthMethod.headers && (
- {t('tools.mcp.modal.headers')} + {t('mcp.modal.headers', { ns: 'tools' })}
-
{t('tools.mcp.modal.headersTip')}
+
{t('mcp.modal.headersTip', { ns: 'tools' })}
- {t('tools.mcp.modal.timeout')} + {t('mcp.modal.timeout', { ns: 'tools' })}
setMcpTimeout(Number(e.target.value))} onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.timeoutPlaceholder')} + placeholder={t('mcp.modal.timeoutPlaceholder', { ns: 'tools' })} />
- {t('tools.mcp.modal.sseReadTimeout')} + {t('mcp.modal.sseReadTimeout', { ns: 'tools' })}
setSseReadTimeout(Number(e.target.value))} onBlur={e => handleBlur(e.target.value.trim())} - placeholder={t('tools.mcp.modal.timeoutPlaceholder')} + placeholder={t('mcp.modal.timeoutPlaceholder', { ns: 'tools' })} />
@@ -406,8 +406,8 @@ const MCPModal = ({ }
- - + +
{showAppIconPicker && ( diff --git a/web/app/components/tools/mcp/provider-card.tsx b/web/app/components/tools/mcp/provider-card.tsx index a7b092e0c0..d8a8e71a82 100644 --- a/web/app/components/tools/mcp/provider-card.tsx +++ b/web/app/components/tools/mcp/provider-card.tsx @@ -96,19 +96,19 @@ const MCPCard = ({
{data.tools.length > 0 && ( -
{t('tools.mcp.toolsCount', { count: data.tools.length })}
+
{t('mcp.toolsCount', { ns: 'tools', count: data.tools.length })}
)} {!data.tools.length && ( -
{t('tools.mcp.noTools')}
+
{t('mcp.noTools', { ns: 'tools' })}
)}
/
-
{`${t('tools.mcp.updateTime')} ${formatTimeFromNow(data.updated_at! * 1000)}`}
+
{`${t('mcp.updateTime', { ns: 'tools' })} ${formatTimeFromNow(data.updated_at! * 1000)}`}
{data.is_team_authorization && data.tools.length > 0 && } {(!data.is_team_authorization || !data.tools.length) && (
- {t('tools.mcp.noConfigured')}
+ {t('mcp.noConfigured', { ns: 'tools' })}
)} @@ -134,10 +134,10 @@ const MCPCard = ({ {isShowDeleteConfirm && ( - {t('tools.mcp.deleteConfirmTitle', { mcp: data.name })} + {t('mcp.deleteConfirmTitle', { ns: 'tools', mcp: data.name })}
)} onCancel={hideDeleteConfirm} diff --git a/web/app/components/tools/provider-list.tsx b/web/app/components/tools/provider-list.tsx index 95f36afcc3..48fd4ef29d 100644 --- a/web/app/components/tools/provider-list.tsx +++ b/web/app/components/tools/provider-list.tsx @@ -49,9 +49,9 @@ const ProviderList = () => { defaultValue: 'builtin', }) const options = [ - { value: 'builtin', text: t('tools.type.builtIn') }, - { value: 'api', text: t('tools.type.custom') }, - { value: 'workflow', text: t('tools.type.workflow') }, + { value: 'builtin', text: t('type.builtIn', { ns: 'tools' }) }, + { value: 'api', text: t('type.custom', { ns: 'tools' }) }, + { value: 'workflow', text: t('type.workflow', { ns: 'tools' }) }, { value: 'mcp', text: 'MCP' }, ] const [tagFilterValue, setTagFilterValue] = useState([]) @@ -194,7 +194,7 @@ const ProviderList = () => {
)} {!filteredCollectionList.length && activeTab === 'builtin' && ( - + )}
{enable_marketplace && activeTab === 'builtin' && ( diff --git a/web/app/components/tools/provider/custom-create-card.tsx b/web/app/components/tools/provider/custom-create-card.tsx index ba0c9e6449..637d17c3c3 100644 --- a/web/app/components/tools/provider/custom-create-card.tsx +++ b/web/app/components/tools/provider/custom-create-card.tsx @@ -7,11 +7,10 @@ import { } from '@remixicon/react' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import Toast from '@/app/components/base/toast' import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-modal' import { useAppContext } from '@/context/app-context' -import I18n, { useDocLink } from '@/context/i18n' +import { useDocLink, useLocale } from '@/context/i18n' import { getLanguage } from '@/i18n-config/language' import { createCustomCollection } from '@/service/tools' @@ -21,7 +20,7 @@ type Props = { const Contribute = ({ onRefreshData }: Props) => { const { t } = useTranslation() - const { locale } = useContext(I18n) + const locale = useLocale() const language = getLanguage(locale) const { isCurrentWorkspaceManager } = useAppContext() @@ -37,7 +36,7 @@ const Contribute = ({ onRefreshData }: Props) => { await createCustomCollection(data) Toast.notify({ type: 'success', - message: t('common.api.actionSuccess'), + message: t('api.actionSuccess', { ns: 'common' }), }) setIsShowEditCustomCollectionModal(false) onRefreshData() @@ -52,13 +51,13 @@ const Contribute = ({ onRefreshData }: Props) => {
-
{t('tools.createCustomTool')}
+
{t('createCustomTool', { ns: 'tools' })}
diff --git a/web/app/components/tools/provider/detail.tsx b/web/app/components/tools/provider/detail.tsx index e0a2281696..a23f722cbe 100644 --- a/web/app/components/tools/provider/detail.tsx +++ b/web/app/components/tools/provider/detail.tsx @@ -6,7 +6,6 @@ import { import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import ActionButton from '@/app/components/base/action-button' import Button from '@/app/components/base/button' import Confirm from '@/app/components/base/confirm' @@ -24,7 +23,7 @@ import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-m import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials' import WorkflowToolModal from '@/app/components/tools/workflow-tool' import { useAppContext } from '@/context/app-context' -import I18n from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' @@ -60,7 +59,7 @@ const ProviderDetail = ({ onRefreshData, }: Props) => { const { t } = useTranslation() - const { locale } = useContext(I18n) + const locale = useLocale() const language = getLanguage(locale) const needAuth = collection.allow_delete || collection.type === CollectionType.model @@ -124,7 +123,7 @@ const ProviderDetail = ({ setCustomCollection(prev => prev ? { ...prev, labels: data.labels } : null) Toast.notify({ type: 'success', - message: t('common.api.actionSuccess'), + message: t('api.actionSuccess', { ns: 'common' }), }) setIsShowEditCustomCollectionModal(false) } @@ -133,7 +132,7 @@ const ProviderDetail = ({ onRefreshData() Toast.notify({ type: 'success', - message: t('common.api.actionSuccess'), + message: t('api.actionSuccess', { ns: 'common' }), }) setIsShowEditCustomCollectionModal(false) } @@ -163,7 +162,7 @@ const ProviderDetail = ({ onRefreshData() Toast.notify({ type: 'success', - message: t('common.api.actionSuccess'), + message: t('api.actionSuccess', { ns: 'common' }), }) setIsShowEditWorkflowToolModal(false) } @@ -177,7 +176,7 @@ const ProviderDetail = ({ getWorkflowToolProvider() Toast.notify({ type: 'success', - message: t('common.api.actionSuccess'), + message: t('api.actionSuccess', { ns: 'common' }), }) setIsShowEditWorkflowToolModal(false) } @@ -275,7 +274,7 @@ const ProviderDetail = ({ onClick={() => setIsShowEditCustomCollectionModal(true)} > -
{t('tools.createTool.editAction')}
+
{t('createTool.editAction', { ns: 'tools' })}
)} {collection.type === CollectionType.workflow && !isDetailLoading && customCollection && ( @@ -285,7 +284,7 @@ const ProviderDetail = ({ className={cn('my-3 w-[183px] shrink-0')} > -
{t('tools.openInStudio')}
+
{t('openInStudio', { ns: 'tools' })}
@@ -294,7 +293,7 @@ const ProviderDetail = ({ onClick={() => setIsShowEditWorkflowToolModal(true)} disabled={!isCurrentWorkspaceManager} > -
{t('tools.createTool.editAction')}
+
{t('createTool.editAction', { ns: 'tools' })}
)} @@ -306,7 +305,7 @@ const ProviderDetail = ({
{(collection.type === CollectionType.builtIn || collection.type === CollectionType.model) && isAuthed && (
- {t('plugin.detailPanel.actionNum', { num: toolList.length, action: toolList.length > 1 ? 'actions' : 'action' })}
+ {t('detailPanel.actionNum', { ns: 'plugin', num: toolList.length, action: toolList.length > 1 ? 'actions' : 'action' })}
  {needAuth && (
  )}
@@ -326,9 +325,9 @@ const ProviderDetail = ({ {(collection.type === CollectionType.builtIn || collection.type === CollectionType.model) && needAuth && !isAuthed && ( <>
- {t('tools.includeToolNum', { num: toolList.length, action: toolList.length > 1 ? 'actions' : 'action' }).toLocaleUpperCase()}
+ {t('includeToolNum', { ns: 'tools', num: toolList.length, action: toolList.length > 1 ? 'actions' : 'action' }).toLocaleUpperCase()}
  ·
- {t('tools.auth.setup').toLocaleUpperCase()}
+ {t('auth.setup', { ns: 'tools' }).toLocaleUpperCase()}
)} {(collection.type === CollectionType.custom) && (
- {t('tools.includeToolNum', { num: toolList.length, action: toolList.length > 1 ? 'actions' : 'action' }).toLocaleUpperCase()}
+ {t('includeToolNum', { ns: 'tools', num: toolList.length, action: toolList.length > 1 ? 'actions' : 'action' }).toLocaleUpperCase()}
)} {(collection.type === CollectionType.workflow) && (
- {t('tools.createTool.toolInput.title').toLocaleUpperCase()}
+ {t('createTool.toolInput.title', { ns: 'tools' }).toLocaleUpperCase()}
)}
@@ -370,7 +369,7 @@ const ProviderDetail = ({
 {item.name}
 {item.type}
- {item.required ? t('tools.createTool.toolInput.required') : ''}
+ {item.required ? t('createTool.toolInput.required', { ns: 'tools' }) : ''}
{item.llm_description}
@@ -387,7 +386,7 @@ const ProviderDetail = ({ await updateBuiltInToolCredential(collection.name, value) Toast.notify({ type: 'success', - message: t('common.api.actionSuccess'), + message: t('api.actionSuccess', { ns: 'common' }), }) await onRefreshData() setShowSettingAuth(false) @@ -396,7 +395,7 @@ const ProviderDetail = ({ await removeBuiltInToolCredential(collection.name) Toast.notify({ type: 'success', - message: t('common.api.actionSuccess'), + message: t('api.actionSuccess', { ns: 'common' }), }) await onRefreshData() setShowSettingAuth(false) @@ -421,8 +420,8 @@ const ProviderDetail = ({ )} {showConfirmDelete && ( setShowConfirmDelete(false)} diff --git a/web/app/components/tools/provider/empty.tsx b/web/app/components/tools/provider/empty.tsx index e79607751e..4940dd6fc5 100644 --- a/web/app/components/tools/provider/empty.tsx +++ b/web/app/components/tools/provider/empty.tsx @@ -32,18 +32,18 @@ const Empty = ({ const hasLink = type && [ToolTypeEnum.Custom, ToolTypeEnum.MCP].includes(type) const Comp = (hasLink ? Link : 'div') as any const linkProps = hasLink ? { href: getLink(type), target: '_blank' } : {} - const renderType = isAgent ? 'agent' : type - const hasTitle = t(`tools.addToolModal.${renderType}.title` as any) as string !== `tools.addToolModal.${renderType}.title` + const renderType = isAgent ? 'agent' as const : type + const hasTitle = renderType && t(`addToolModal.${renderType}.title`, { ns: 'tools' }) !== `addToolModal.${renderType}.title` return (
- {hasTitle ? t(`tools.addToolModal.${renderType}.title` as any) as string : 'No tools available'} + {(hasTitle && renderType) ? t(`addToolModal.${renderType}.title`, { ns: 'tools' }) : 'No tools available'}
- {(!isAgent && hasTitle) && ( + {(!isAgent && hasTitle && renderType) && ( - {t(`tools.addToolModal.${renderType}.tip` as any) as string} + {t(`addToolModal.${renderType}.tip`, { ns: 'tools' })} {' '} {hasLink && } diff --git a/web/app/components/tools/provider/tool-item.tsx b/web/app/components/tools/provider/tool-item.tsx index b240bf6a41..4e28a7427b 100644 --- a/web/app/components/tools/provider/tool-item.tsx +++ b/web/app/components/tools/provider/tool-item.tsx @@ -2,9 +2,8 @@ import type { Collection, Tool } from '../types' import * as React from 'react' import { useState } from 'react' -import { useContext } from 'use-context-selector' import SettingBuiltInTool from '@/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool' -import I18n from '@/context/i18n' +import { useLocale } from '@/context/i18n' import { getLanguage } from '@/i18n-config/language' import { cn } from '@/utils/classnames' @@ -23,7 +22,7 @@ const ToolItem = ({ isBuiltIn, isModel, }: Props) => { - const { locale } = useContext(I18n) + const locale = useLocale() const language = getLanguage(locale) const [showDetail, setShowDetail] = useState(false) diff --git a/web/app/components/tools/setting/build-in/config-credentials.tsx b/web/app/components/tools/setting/build-in/config-credentials.tsx index 033052e8a1..863b3ba352 100644 --- a/web/app/components/tools/setting/build-in/config-credentials.tsx +++ b/web/app/components/tools/setting/build-in/config-credentials.tsx @@ -1,7 +1,7 @@ 'use client' import type { FC } from 'react' import type { Collection } from '../../types' -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import * as React from 'react' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -53,7 +53,7 @@ const ConfigCredential: FC = ({ const handleSave = async () => { for (const field of credentialSchema) { if (field.required && !tempCredential[field.name]) { - Toast.notify({ type: 'error', message: t('common.errorMsg.fieldRequired', { field: field.label[language] || field.label.en_US }) }) + Toast.notify({ type: 'error', message: t('errorMsg.fieldRequired', { ns: 'common', field: field.label[language] || field.label.en_US }) }) return } } @@ -71,8 +71,8 @@ const ConfigCredential: FC = ({ = ({ rel="noopener noreferrer" className="inline-flex items-center text-xs text-text-accent" > - {t('tools.howToGet')} + {t('howToGet', { ns: 'tools' })} ) @@ -111,12 +111,12 @@ const ConfigCredential: FC = ({
{ (collection.is_team_authorization && !isHideRemoveBtn) && ( - + ) }
- - + +
diff --git a/web/app/components/tools/workflow-tool/configure-button.tsx b/web/app/components/tools/workflow-tool/configure-button.tsx index f142989ff6..6526722b63 100644 --- a/web/app/components/tools/workflow-tool/configure-button.tsx +++ b/web/app/components/tools/workflow-tool/configure-button.tsx @@ -166,7 +166,7 @@ const WorkflowToolConfigureButton = ({ getDetail(workflowAppId) Toast.notify({ type: 'success', - message: t('common.api.actionSuccess'), + message: t('api.actionSuccess', { ns: 'common' }), }) setShowModal(false) } @@ -187,7 +187,7 @@ const WorkflowToolConfigureButton = ({ getDetail(workflowAppId) Toast.notify({ type: 'success', - message: t('common.api.actionSuccess'), + message: t('api.actionSuccess', { ns: 'common' }), }) setShowModal(false) } @@ -214,14 +214,14 @@ const WorkflowToolConfigureButton = ({ >
- {t('workflow.common.workflowAsTool')}
+ {t('common.workflowAsTool', { ns: 'workflow' })}
 {!published && (
- {t('workflow.common.configureRequired')}
+ {t('common.configureRequired', { ns: 'workflow' })}
 )}
@@ -232,10 +232,10 @@ const WorkflowToolConfigureButton = ({ >
- {t('workflow.common.workflowAsTool')}
+ {t('common.workflowAsTool', { ns: 'workflow' })}
)} @@ -253,7 +253,7 @@ const WorkflowToolConfigureButton = ({ onClick={() => setShowModal(true)} disabled={!isCurrentWorkspaceManager || disabled} > - {t('workflow.common.configure')} + {t('common.configure', { ns: 'workflow' })} {outdated && } {outdated && (
- {t('workflow.common.workflowAsToolTip')}
+ {t('common.workflowAsToolTip', { ns: 'workflow' })}
)} diff --git a/web/app/components/tools/workflow-tool/confirm-modal/index.tsx b/web/app/components/tools/workflow-tool/confirm-modal/index.tsx index e1a7dff113..0c7e083a56 100644 --- a/web/app/components/tools/workflow-tool/confirm-modal/index.tsx +++ b/web/app/components/tools/workflow-tool/confirm-modal/index.tsx @@ -1,7 +1,7 @@ 'use client' import { RiCloseLine } from '@remixicon/react' -import { noop } from 'es-toolkit/compat' +import { noop } from 'es-toolkit/function' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' @@ -29,14 +29,14 @@ const ConfirmModal = ({ show, onConfirm, onClose }: ConfirmModalProps) => {
-
{t('tools.createTool.confirmTitle')}
+
{t('createTool.confirmTitle', { ns: 'tools' })}
- {t('tools.createTool.confirmTip')}
+ {t('createTool.confirmTip', { ns: 'tools' })}
- - + +
diff --git a/web/app/components/tools/workflow-tool/index.tsx b/web/app/components/tools/workflow-tool/index.tsx index 8804a4128d..9a2c6a4c4c 100644 --- a/web/app/components/tools/workflow-tool/index.tsx +++ b/web/app/components/tools/workflow-tool/index.tsx @@ -55,19 +55,19 @@ const WorkflowToolAsModal: FC = ({ const reservedOutputParameters: WorkflowToolProviderOutputParameter[] = [ { name: 'text', - description: t('workflow.nodes.tool.outputVars.text'), + description: t('nodes.tool.outputVars.text', { ns: 'workflow' }), type: VarType.string, reserved: true, }, { name: 'files', - description: t('workflow.nodes.tool.outputVars.files.title'), + description: t('nodes.tool.outputVars.files.title', { ns: 'workflow' }), type: VarType.arrayFile, reserved: true, }, { name: 'json', - description: t('workflow.nodes.tool.outputVars.json'), + description: t('nodes.tool.outputVars.json', { ns: 'workflow' }), type: VarType.arrayObject, reserved: true, }, @@ -104,13 +104,13 @@ const WorkflowToolAsModal: FC = ({ const onConfirm = () => { let errorMessage = '' if (!label) - errorMessage = t('common.errorMsg.fieldRequired', { field: t('tools.createTool.name') }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t('createTool.name', { ns: 'tools' }) }) if (!name) - errorMessage = t('common.errorMsg.fieldRequired', { field: t('tools.createTool.nameForToolCall') }) + errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t('createTool.nameForToolCall', { ns: 'tools' }) }) if (!isNameValid(name)) - errorMessage = t('tools.createTool.nameForToolCall') + t('tools.createTool.nameForToolCallTip') + errorMessage = t('createTool.nameForToolCall', { ns: 'tools' }) + t('createTool.nameForToolCallTip', { ns: 'tools' }) if (errorMessage) { Toast.notify({ @@ -152,7 +152,7 @@ const WorkflowToolAsModal: FC = ({ = ({ {/* name & icon */}
- {t('tools.createTool.name')}
+ {t('createTool.name', { ns: 'tools' })}
  {' '}
  *
@@ -171,7 +171,7 @@ const WorkflowToolAsModal: FC = ({ { setShowEmojiPicker(true) }} className="cursor-pointer" iconType="emoji" icon={emoji.content} background={emoji.background} /> setLabel(e.target.value)} /> @@ -180,46 +180,46 @@ const WorkflowToolAsModal: FC = ({ {/* name for tool call */}
- {t('tools.createTool.nameForToolCall')}
+ {t('createTool.nameForToolCall', { ns: 'tools' })}
  {' '}
  *
- {t('tools.createTool.nameForToolCallPlaceHolder')}
+ {t('createTool.nameForToolCallPlaceHolder', { ns: 'tools' })}
)} />
setName(e.target.value)} /> {!isNameValid(name) && ( -
{t('tools.createTool.nameForToolCallTip')}
+
{t('createTool.nameForToolCallTip', { ns: 'tools' })}
)}
{/* description */}
-
{t('tools.createTool.description')}
+
{t('createTool.description', { ns: 'tools' })}