From aacea166d702a91e9f73eca181ad9bb971a2627f Mon Sep 17 00:00:00 2001 From: lyzno1 <92089059+lyzno1@users.noreply.github.com> Date: Tue, 19 Aug 2025 13:47:38 +0800 Subject: [PATCH] fix: resolve merge conflict between Features removal and validation enhancement (#24150) --- .gitignore | 2 + CLAUDE.md | 83 + api/.env.example | 16 + api/configs/feature/__init__.py | 9 + api/configs/middleware/cache/redis_config.py | 20 + api/controllers/common/helpers.py | 5 +- api/controllers/console/app/annotation.py | 6 +- api/controllers/console/app/generator.py | 118 ++ .../console/datasets/datasets_document.py | 8 +- api/controllers/console/datasets/metadata.py | 4 +- api/controllers/service_api/app/annotation.py | 6 +- .../service_api/dataset/dataset.py | 6 +- .../service_api/dataset/metadata.py | 4 +- api/core/app/apps/chat/app_generator.py | 4 +- api/core/app/apps/completion/app_generator.py | 4 +- .../apps/message_based_app_queue_manager.py | 10 - api/core/llm_generator/llm_generator.py | 184 ++ api/core/llm_generator/prompts.py | 113 ++ .../vdb/oceanbase/oceanbase_vector.py | 17 +- .../datasource/vdb/qdrant/qdrant_vector.py | 9 +- api/core/tools/entities/tool_entities.py | 12 +- api/core/workflow/nodes/tool/tool_node.py | 27 - .../document_index_event.py | 0 .../event_handlers/create_document_index.py | 2 +- api/extensions/ext_celery.py | 62 +- api/extensions/ext_redis.py | 179 +- api/mypy.ini | 3 +- api/pyproject.toml | 2 +- .../clean_workflow_runlogs_precise.py | 155 ++ api/services/dataset_service.py | 23 +- api/services/plugin/oauth_service.py | 4 +- api/tasks/deal_dataset_vector_index_task.py | 3 +- .../vdb/qdrant/test_qdrant.py | 9 + .../services/test_feature_service.py | 1785 +++++++++++++++++ .../services/test_metadata_service.py | 1144 +++++++++++ .../test_model_load_balancing_service.py | 474 +++++ .../unit_tests/extensions/test_celery_ssl.py | 149 ++ api/uv.lock | 19 +- docker/.env.example | 17 + docker/docker-compose.yaml | 7 + .../account/account-page/AvatarWithEdit.tsx | 51 +- .../config-prompt/simple-prompt-input.tsx | 16 +- .../config/automatic/get-automatic-res.tsx | 287 +-- .../config/automatic/idea-output.tsx | 48 + .../instruction-editor-in-workflow.tsx | 58 + .../config/automatic/instruction-editor.tsx | 117 ++ .../automatic/prompt-res-in-workflow.tsx | 55 + .../config/automatic/prompt-res.tsx | 31 + .../config/automatic/prompt-toast.tsx | 54 + .../config/automatic/res-placeholder.tsx | 18 + .../configuration/config/automatic/result.tsx | 97 + .../config/automatic/style.module.css | 8 +- .../configuration/config/automatic/types.ts | 4 + .../config/automatic/use-gen-data.ts | 36 + .../config/automatic/version-selector.tsx | 103 + .../code-generator/get-code-generator-res.tsx | 174 +- .../experience-enhance-group/index.tsx | 43 - .../more-like-this/index.tsx | 51 - web/app/components/app/log/list.tsx | 316 ++- .../base/chat/chat/loading-anim/index.tsx | 3 +- .../conversation-opener/modal.tsx | 26 +- .../vender/line/general/code-assistant.svg | 6 + .../assets/vender/line/general/magic-edit.svg | 6 + .../vender/line/general/CodeAssistant.json | 53 + .../src/vender/line/general/CodeAssistant.tsx | 20 + .../src/vender/line/general/MagicEdit.json | 55 + .../src/vender/line/general/MagicEdit.tsx | 20 + .../icons/src/vender/line/general/index.ts | 2 + .../base/prompt-editor/constants.tsx | 4 + .../components/base/prompt-editor/index.tsx | 67 +- .../plugins/component-picker-block/hooks.tsx | 54 +- .../plugins/component-picker-block/index.tsx | 38 +- .../plugins/current-block/component.tsx | 44 + .../current-block-replacement-block.tsx | 61 + .../plugins/current-block/index.tsx | 66 + .../plugins/current-block/node.tsx | 78 + .../plugins/error-message-block/component.tsx | 40 + .../error-message-block-replacement-block.tsx | 60 + .../plugins/error-message-block/index.tsx | 65 + .../plugins/error-message-block/node.tsx | 67 + .../plugins/last-run-block/component.tsx | 40 + .../plugins/last-run-block/index.tsx | 65 + .../last-run-block-replacement-block.tsx | 60 + .../plugins/last-run-block/node.tsx | 67 + .../components/base/prompt-editor/types.ts | 20 + .../documents/detail/completed/common/tag.tsx | 2 +- .../components/datasets/documents/list.tsx | 12 +- .../components/datasets/hit-testing/index.tsx | 2 +- web/app/components/explore/category.tsx | 2 +- .../actions/{ => commands}/command-bus.ts | 4 +- .../goto-anything/actions/commands/index.ts | 15 + .../actions/commands/language.tsx | 53 + .../actions/commands/registry.ts | 233 +++ .../goto-anything/actions/commands/slash.tsx | 52 + .../{run-theme.tsx => commands/theme.tsx} | 51 +- .../goto-anything/actions/commands/types.ts | 33 + .../components/goto-anything/actions/index.ts | 182 +- .../goto-anything/actions/run-language.tsx | 33 - .../components/goto-anything/actions/run.tsx | 97 - .../components/goto-anything/actions/types.ts | 2 +- .../goto-anything/command-selector.tsx | 2 +- web/app/components/goto-anything/index.tsx | 32 +- .../header/account-dropdown/support.tsx | 2 +- .../plugin-detail-panel/detail-header.tsx | 4 + .../plugins/reference-setting-modal/modal.tsx | 9 +- web/app/components/tools/labels/store.ts | 15 + .../workflow-header/app-publisher-trigger.tsx | 20 +- .../workflow-app/hooks/use-configs-map.ts | 1 + .../workflow-app/hooks/use-workflow-run.ts | 1 - .../components/workflow/hooks-store/store.ts | 1 + .../workflow/hooks/use-workflow-variables.ts | 3 +- .../nodes/_base/components/agent-strategy.tsx | 1 + .../components/code-generator-button.tsx | 16 +- .../nodes/_base/components/editor/base.tsx | 12 +- .../components/editor/code-editor/index.tsx | 3 + .../nodes/_base/components/prompt/editor.tsx | 10 +- .../variable/var-reference-vars.tsx | 75 +- .../nodes/_base/hooks/use-toggle-expend.ts | 6 +- .../components/workflow/nodes/code/panel.tsx | 1 + .../llm/components/config-prompt-item.tsx | 12 +- .../nodes/llm/components/config-prompt.tsx | 2 + .../llm/components/prompt-generator-btn.tsx | 16 +- .../panel/debug-and-preview/index.tsx | 2 +- web/app/components/workflow/types.ts | 1 + .../workflow/variable-inspect/panel.tsx | 1 + .../workflow/variable-inspect/right.tsx | 112 +- web/i18n/de-DE/app-debug.ts | 18 + web/i18n/de-DE/app.ts | 2 + web/i18n/de-DE/common.ts | 4 + web/i18n/de-DE/workflow.ts | 1 + web/i18n/en-US/app-debug.ts | 21 +- web/i18n/en-US/app.ts | 6 +- web/i18n/en-US/common.ts | 4 + web/i18n/en-US/workflow.ts | 1 + web/i18n/es-ES/app-debug.ts | 18 + web/i18n/es-ES/app.ts | 2 + web/i18n/es-ES/common.ts | 4 + web/i18n/es-ES/workflow.ts | 1 + web/i18n/fa-IR/app-debug.ts | 18 + web/i18n/fa-IR/app.ts | 2 + web/i18n/fa-IR/billing.ts | 40 +- web/i18n/fa-IR/common.ts | 4 + web/i18n/fa-IR/workflow.ts | 1 + web/i18n/fr-FR/app-debug.ts | 18 + web/i18n/fr-FR/app.ts | 2 + web/i18n/fr-FR/billing.ts | 34 +- web/i18n/fr-FR/common.ts | 4 + web/i18n/fr-FR/workflow.ts | 1 + web/i18n/hi-IN/app-debug.ts | 22 +- web/i18n/hi-IN/app.ts | 2 + web/i18n/hi-IN/billing.ts | 26 +- web/i18n/hi-IN/common.ts | 4 + web/i18n/hi-IN/workflow.ts | 1 + web/i18n/it-IT/app-debug.ts | 18 + web/i18n/it-IT/app.ts | 2 + web/i18n/it-IT/billing.ts | 42 +- web/i18n/it-IT/common.ts | 4 + web/i18n/it-IT/workflow.ts | 1 + web/i18n/ja-JP/app-annotation.ts | 2 +- web/i18n/ja-JP/app-debug.ts | 18 + web/i18n/ja-JP/app.ts | 2 + web/i18n/ja-JP/common.ts | 6 +- web/i18n/ja-JP/workflow.ts | 2 + web/i18n/ko-KR/app-debug.ts | 18 + web/i18n/ko-KR/app.ts | 2 + web/i18n/ko-KR/common.ts | 4 + web/i18n/ko-KR/workflow.ts | 1 + web/i18n/pl-PL/app-debug.ts | 18 + web/i18n/pl-PL/app.ts | 2 + web/i18n/pl-PL/billing.ts | 42 +- web/i18n/pl-PL/common.ts | 4 + web/i18n/pl-PL/workflow.ts | 1 + web/i18n/pt-BR/app-debug.ts | 18 + web/i18n/pt-BR/app.ts | 2 + web/i18n/pt-BR/billing.ts | 39 +- web/i18n/pt-BR/common.ts | 4 + web/i18n/pt-BR/workflow.ts | 1 + web/i18n/ro-RO/app-debug.ts | 18 + web/i18n/ro-RO/app.ts | 2 + web/i18n/ro-RO/billing.ts | 32 +- web/i18n/ro-RO/common.ts | 4 + web/i18n/ro-RO/workflow.ts | 1 + web/i18n/ru-RU/app-debug.ts | 18 + web/i18n/ru-RU/app.ts | 2 + web/i18n/ru-RU/billing.ts | 34 +- web/i18n/ru-RU/common.ts | 4 + web/i18n/ru-RU/workflow.ts | 1 + web/i18n/sl-SI/app-debug.ts | 18 + web/i18n/sl-SI/app.ts | 2 + web/i18n/sl-SI/billing.ts | 38 +- web/i18n/sl-SI/common.ts | 4 + web/i18n/sl-SI/workflow.ts | 1 + web/i18n/th-TH/app-debug.ts | 18 + web/i18n/th-TH/app.ts | 2 + web/i18n/th-TH/billing.ts | 40 +- web/i18n/th-TH/common.ts | 4 + web/i18n/th-TH/workflow.ts | 1 + web/i18n/tr-TR/app-debug.ts | 18 + web/i18n/tr-TR/app.ts | 3 + web/i18n/tr-TR/billing.ts | 34 +- web/i18n/tr-TR/common.ts | 4 + web/i18n/tr-TR/workflow.ts | 1 + web/i18n/uk-UA/app-debug.ts | 18 + web/i18n/uk-UA/app.ts | 2 + web/i18n/uk-UA/billing.ts | 36 +- web/i18n/uk-UA/common.ts | 4 + web/i18n/uk-UA/workflow.ts | 1 + web/i18n/vi-VN/app-debug.ts | 18 + web/i18n/vi-VN/app.ts | 2 + web/i18n/vi-VN/billing.ts | 36 +- web/i18n/vi-VN/common.ts | 4 + web/i18n/vi-VN/workflow.ts | 1 + web/i18n/zh-Hans/app-debug.ts | 23 +- web/i18n/zh-Hans/app.ts | 6 +- web/i18n/zh-Hans/common.ts | 4 + web/i18n/zh-Hans/workflow.ts | 1 + web/i18n/zh-Hant/app-debug.ts | 18 + web/i18n/zh-Hant/app.ts | 2 + web/i18n/zh-Hant/billing.ts | 16 + web/i18n/zh-Hant/common.ts | 4 + web/i18n/zh-Hant/workflow.ts | 1 + web/service/debug.ts | 20 +- web/service/use-apps.ts | 16 +- web/types/app.ts | 8 - 224 files changed, 8473 insertions(+), 1077 deletions(-) create mode 100644 CLAUDE.md rename api/events/{event_handlers => }/document_index_event.py (100%) create mode 100644 api/schedule/clean_workflow_runlogs_precise.py create mode 100644 api/tests/test_containers_integration_tests/services/test_feature_service.py create mode 100644 api/tests/test_containers_integration_tests/services/test_metadata_service.py create mode 100644 api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py create mode 100644 api/tests/unit_tests/extensions/test_celery_ssl.py create mode 100644 web/app/components/app/configuration/config/automatic/idea-output.tsx create mode 100644 web/app/components/app/configuration/config/automatic/instruction-editor-in-workflow.tsx create mode 100644 web/app/components/app/configuration/config/automatic/instruction-editor.tsx create mode 100644 web/app/components/app/configuration/config/automatic/prompt-res-in-workflow.tsx create mode 100644 web/app/components/app/configuration/config/automatic/prompt-res.tsx create mode 100644 web/app/components/app/configuration/config/automatic/prompt-toast.tsx create mode 100644 web/app/components/app/configuration/config/automatic/res-placeholder.tsx create mode 100644 web/app/components/app/configuration/config/automatic/result.tsx create mode 100644 web/app/components/app/configuration/config/automatic/types.ts create mode 100644 web/app/components/app/configuration/config/automatic/use-gen-data.ts create mode 100644 web/app/components/app/configuration/config/automatic/version-selector.tsx delete mode 100644 web/app/components/app/configuration/features/experience-enhance-group/index.tsx delete mode 100644 web/app/components/app/configuration/features/experience-enhance-group/more-like-this/index.tsx create mode 100644 web/app/components/base/icons/assets/vender/line/general/code-assistant.svg create mode 100644 web/app/components/base/icons/assets/vender/line/general/magic-edit.svg create mode 100644 web/app/components/base/icons/src/vender/line/general/CodeAssistant.json create mode 100644 web/app/components/base/icons/src/vender/line/general/CodeAssistant.tsx create mode 100644 web/app/components/base/icons/src/vender/line/general/MagicEdit.json create mode 100644 web/app/components/base/icons/src/vender/line/general/MagicEdit.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/current-block/component.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/current-block/current-block-replacement-block.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/current-block/index.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/current-block/node.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/error-message-block/component.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/error-message-block/error-message-block-replacement-block.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/error-message-block/index.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/error-message-block/node.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/last-run-block/component.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/last-run-block/index.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/last-run-block/last-run-block-replacement-block.tsx create mode 100644 web/app/components/base/prompt-editor/plugins/last-run-block/node.tsx rename web/app/components/goto-anything/actions/{ => commands}/command-bus.ts (82%) create mode 100644 web/app/components/goto-anything/actions/commands/index.ts create mode 100644 web/app/components/goto-anything/actions/commands/language.tsx create mode 100644 web/app/components/goto-anything/actions/commands/registry.ts create mode 100644 web/app/components/goto-anything/actions/commands/slash.tsx rename web/app/components/goto-anything/actions/{run-theme.tsx => commands/theme.tsx} (57%) create mode 100644 web/app/components/goto-anything/actions/commands/types.ts delete mode 100644 web/app/components/goto-anything/actions/run-language.tsx delete mode 100644 web/app/components/goto-anything/actions/run.tsx create mode 100644 web/app/components/tools/labels/store.ts diff --git a/.gitignore b/.gitignore index 5c68d89a4d..30432c4302 100644 --- a/.gitignore +++ b/.gitignore @@ -197,6 +197,8 @@ sdks/python-client/dify_client.egg-info !.vscode/README.md pyrightconfig.json api/.vscode +# vscode Code History Extension +.history .idea/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..7ce04382c9 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,83 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management. + +The codebase consists of: +- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture +- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19 +- **Docker deployment** (`/docker`): Containerized deployment configurations + +## Development Commands + +### Backend (API) + +All Python commands must be prefixed with `uv run --project api`: + +```bash +# Start development servers +./dev/start-api # Start API server +./dev/start-worker # Start Celery worker + +# Run tests +uv run --project api pytest # Run all tests +uv run --project api pytest tests/unit_tests/ # Unit tests only +uv run --project api pytest tests/integration_tests/ # Integration tests + +# Code quality +./dev/reformat # Run all formatters and linters +uv run --project api ruff check --fix ./ # Fix linting issues +uv run --project api ruff format ./ # Format code +uv run --project api mypy . # Type checking +``` + +### Frontend (Web) + +```bash +cd web +pnpm lint # Run ESLint +pnpm eslint-fix # Fix ESLint issues +pnpm test # Run Jest tests +``` + +## Testing Guidelines + +### Backend Testing +- Use `pytest` for all backend tests +- Write tests first (TDD approach) +- Test structure: Arrange-Act-Assert + +## Code Style Requirements + +### Python +- Use type hints for all functions and class attributes +- No `Any` types unless absolutely necessary +- Implement special methods (`__repr__`, `__str__`) appropriately + +### TypeScript/JavaScript +- Strict TypeScript configuration +- ESLint with Prettier integration +- Avoid `any` type + +## Important Notes + +- **Environment Variables**: Always use UV for Python commands: `uv run --project api ` +- **Comments**: Only write meaningful comments that explain "why", not "what" +- **File Creation**: Always prefer editing existing files over creating new ones +- **Documentation**: Don't create documentation files unless explicitly requested +- **Code Quality**: Always run `./dev/reformat` before committing backend changes + +## Common Development Tasks + +### Adding a New API Endpoint +1. Create controller in `/api/controllers/` +2. Add service logic in `/api/services/` +3. Update routes in controller's `__init__.py` +4. Write tests in `/api/tests/` + +## Project-Specific Conventions + +- All async tasks use Celery with Redis as broker diff --git a/api/.env.example b/api/.env.example index 4beabfecea..3052dbfe2b 100644 --- a/api/.env.example +++ b/api/.env.example @@ -42,6 +42,15 @@ REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false +# SSL configuration for Redis (when REDIS_USE_SSL=true) +REDIS_SSL_CERT_REQS=CERT_NONE +# Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED +REDIS_SSL_CA_CERTS= +# Path to CA certificate file for SSL verification +REDIS_SSL_CERTFILE= +# Path to client certificate file for SSL authentication +REDIS_SSL_KEYFILE= +# Path to client private key file for SSL authentication REDIS_DB=0 # redis Sentinel configuration. @@ -469,6 +478,13 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node # API workflow run repository implementation API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository +# Workflow log cleanup configuration +# Enable automatic cleanup of workflow run logs to manage database size +WORKFLOW_LOG_CLEANUP_ENABLED=true +# Number of days to retain workflow run logs (default: 30 days) +WORKFLOW_LOG_RETENTION_DAYS=30 +# Batch size for workflow log cleanup operations (default: 100) +WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100 # App configuration APP_MAX_EXECUTION_TIME=1200 diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 0b2f99aece..2bccc4b7a0 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -968,6 +968,14 @@ class AccountConfig(BaseSettings): ) +class WorkflowLogConfig(BaseSettings): + WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup") + WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs") + WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field( + default=100, description="Batch size for workflow run log cleanup operations" + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -1003,5 +1011,6 @@ class FeatureConfig( HostedServiceConfig, CeleryBeatConfig, CeleryScheduleTasksConfig, + WorkflowLogConfig, ): pass diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 916f52e165..16dca98cfa 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -39,6 +39,26 @@ class RedisConfig(BaseSettings): default=False, ) + REDIS_SSL_CERT_REQS: str = Field( + description="SSL certificate requirements (CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED)", + default="CERT_NONE", + ) + + REDIS_SSL_CA_CERTS: Optional[str] = Field( + description="Path to the CA certificate file for SSL verification", + default=None, + ) + + REDIS_SSL_CERTFILE: Optional[str] = Field( + description="Path to the client certificate file for SSL authentication", + default=None, + ) + + REDIS_SSL_KEYFILE: Optional[str] = Field( + description="Path to the client private key file for SSL authentication", + default=None, + ) + REDIS_USE_SENTINEL: Optional[bool] = Field( description="Enable Redis Sentinel mode for high availability", default=False, diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py index 008f1f0f7a..6a5197635e 100644 --- a/api/controllers/common/helpers.py +++ b/api/controllers/common/helpers.py @@ -1,3 +1,4 @@ +import contextlib import mimetypes import os import platform @@ -65,10 +66,8 @@ def guess_file_info_from_response(response: httpx.Response): # Use python-magic to guess MIME type if still unknown or generic if mimetype == "application/octet-stream" and magic is not None: - try: + with contextlib.suppress(magic.MagicException): mimetype = magic.from_buffer(response.content[:1024], mime=True) - except magic.MagicException: - pass extension = os.path.splitext(filename)[1] diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 493a9a52e2..2caa908d4a 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask import request from flask_login import current_user from flask_restful import Resource, marshal, marshal_with, reqparse @@ -24,7 +26,7 @@ class AnnotationReplyActionApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") - def post(self, app_id, action): + def post(self, app_id, action: Literal["enable", "disable"]): if not current_user.is_editor: raise Forbidden() @@ -38,8 +40,6 @@ class AnnotationReplyActionApi(Resource): result = AppAnnotationService.enable_app_annotation(args, app_id) elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_id) - else: - raise ValueError("Unsupported annotation reply action") return result, 200 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 4847a2cab8..2a81d1b645 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + from flask_login import current_user from flask_restful import Resource, reqparse @@ -10,6 +12,9 @@ from controllers.console.app.error import ( ) from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.helper.code_executor.code_node_provider import CodeNodeProvider +from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider +from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required @@ -107,6 +112,119 @@ class RuleStructuredOutputGenerateApi(Resource): return structured_output +class InstructionGenerateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("flow_id", type=str, required=True, default="", location="json") + parser.add_argument("node_id", type=str, required=False, default="", location="json") + parser.add_argument("current", type=str, required=False, default="", location="json") + parser.add_argument("language", type=str, required=False, default="javascript", location="json") + parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") + parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + parser.add_argument("ideal_output", type=str, required=False, default="", location="json") + args = parser.parse_args() + providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] + code_provider: type[CodeNodeProvider] | None = next( + (p for p in providers if p.is_accept_language(args["language"])), None + ) + code_template = code_provider.get_default_code() if code_provider else "" + try: + # Generate from nothing for a workflow node + if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "": + from models import App, db + from services.workflow_service import WorkflowService + + app = db.session.query(App).filter(App.id == args["flow_id"]).first() + if not app: + return {"error": f"app {args['flow_id']} not found"}, 400 + workflow = WorkflowService().get_draft_workflow(app_model=app) + if not workflow: + return {"error": f"workflow {args['flow_id']} not found"}, 400 + nodes: Sequence = workflow.graph_dict["nodes"] + node = [node for node in nodes if node["id"] == args["node_id"]] + if len(node) == 0: + return {"error": f"node {args['node_id']} not found"}, 400 + node_type = node[0]["data"]["type"] + match node_type: + case "llm": + return LLMGenerator.generate_rule_config( + current_user.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + no_variable=True, + ) + case "agent": + return LLMGenerator.generate_rule_config( + current_user.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + no_variable=True, + ) + case "code": + return LLMGenerator.generate_code( + tenant_id=current_user.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + code_language=args["language"], + ) + case _: + return {"error": f"invalid node type: {node_type}"} + if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow + return LLMGenerator.instruction_modify_legacy( + tenant_id=current_user.current_tenant_id, + flow_id=args["flow_id"], + current=args["current"], + instruction=args["instruction"], + model_config=args["model_config"], + ideal_output=args["ideal_output"], + ) + if args["node_id"] != "" and args["current"] != "": # For workflow node + return LLMGenerator.instruction_modify_workflow( + tenant_id=current_user.current_tenant_id, + flow_id=args["flow_id"], + node_id=args["node_id"], + current=args["current"], + instruction=args["instruction"], + model_config=args["model_config"], + ideal_output=args["ideal_output"], + ) + return {"error": "incompatible parameters"}, 400 + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + + +class InstructionGenerationTemplateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self) -> dict: + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, default=False, location="json") + args = parser.parse_args() + match args["type"]: + case "prompt": + from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT + + return {"data": INSTRUCTION_GENERATE_TEMPLATE_PROMPT} + case "code": + from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_CODE + + return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE} + case _: + raise ValueError(f"Invalid type: {args['type']}") + + api.add_resource(RuleGenerateApi, "/rule-generate") api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate") +api.add_resource(InstructionGenerateApi, "/instruction-generate") +api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 4e0955bd43..413b018baa 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,6 +1,6 @@ import logging from argparse import ArgumentTypeError -from typing import cast +from typing import Literal, cast from flask import request from flask_login import current_user @@ -758,7 +758,7 @@ class DocumentProcessingApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id, document_id, action): + def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]): dataset_id = str(dataset_id) document_id = str(document_id) document = self.get_document(dataset_id, document_id) @@ -784,8 +784,6 @@ class DocumentProcessingApi(DocumentResource): document.paused_at = None document.is_paused = False db.session.commit() - else: - raise InvalidActionError() return {"result": "success"}, 200 @@ -840,7 +838,7 @@ class DocumentStatusApi(DocumentResource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id, action): + def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 65f76fb402..1b5570285d 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import NotFound @@ -100,7 +102,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): @login_required @account_initialization_required @enterprise_license_required - def post(self, dataset_id, action): + def post(self, dataset_id, action: Literal["enable", "disable"]): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 9b22c535f4..23446bb702 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask import request from flask_restful import Resource, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden @@ -15,7 +17,7 @@ from services.annotation_service import AppAnnotationService class AnnotationReplyActionApi(Resource): @validate_app_token - def post(self, app_model: App, action): + def post(self, app_model: App, action: Literal["enable", "disable"]): parser = reqparse.RequestParser() parser.add_argument("score_threshold", required=True, type=float, location="json") parser.add_argument("embedding_provider_name", required=True, type=str, location="json") @@ -25,8 +27,6 @@ class AnnotationReplyActionApi(Resource): result = AppAnnotationService.enable_app_annotation(args, app_model.id) elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_model.id) - else: - raise ValueError("Unsupported annotation reply action") return result, 200 diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 29eef41253..35b1efeff6 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask import request from flask_restful import marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -358,14 +360,14 @@ class DatasetApi(DatasetApiResource): class DocumentStatusApi(DatasetApiResource): """Resource for batch document status operations.""" - def patch(self, tenant_id, dataset_id, action): + def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): """ Batch update document status. Args: tenant_id: tenant id dataset_id: dataset id - action: action to perform (enable, disable, archive, un_archive) + action: action to perform (Literal["enable", "disable", "archive", "un_archive"]) Returns: dict: A dictionary with a key 'result' and a value 'success' diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 6ba818c5fc..75a0b18285 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask_login import current_user # type: ignore from flask_restful import marshal, reqparse from werkzeug.exceptions import NotFound @@ -77,7 +79,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, action): + def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 0c76cc39ae..c273776eb1 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -140,7 +140,9 @@ class ChatAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - trace_manager = TraceQueueManager(app_id=app_model.id) + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) # init application generate entity application_generate_entity = ChatAppGenerateEntity( diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 9356bd1cea..64dade2968 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -124,7 +124,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - trace_manager = TraceQueueManager(app_model.id) + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) # init application generate entity application_generate_entity = CompletionAppGenerateEntity( diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 8507f23f17..4100a0d5a9 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -6,7 +6,6 @@ from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAdvancedChatMessageEndEvent, QueueErrorEvent, - QueueMessage, QueueMessageEndEvent, QueueStopEvent, ) @@ -22,15 +21,6 @@ class MessageBasedAppQueueManager(AppQueueManager): self._app_mode = app_mode self._message_id = str(message_id) - def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: - return MessageQueueMessage( - task_id=self._task_id, - message_id=self._message_id, - conversation_id=self._conversation_id, - app_mode=self._app_mode, - event=event, - ) - def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ Publish event to queue diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 47e5a79160..64fc3a3e80 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -1,6 +1,7 @@ import json import logging import re +from collections.abc import Sequence from typing import Optional, cast import json_repair @@ -11,6 +12,8 @@ from core.llm_generator.prompts import ( CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, + LLM_MODIFY_CODE_SYSTEM, + LLM_MODIFY_PROMPT_SYSTEM, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, SYSTEM_STRUCTURED_OUTPUT_GENERATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, @@ -24,6 +27,9 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from core.workflow.graph_engine.entities.event import AgentLogEvent +from models import App, Message, WorkflowNodeExecutionModel, db class LLMGenerator: @@ -388,3 +394,181 @@ class LLMGenerator: except Exception as e: logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name")) return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} + + @staticmethod + def instruction_modify_legacy( + tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None + ) -> dict: + app: App | None = db.session.query(App).filter(App.id == flow_id).first() + last_run: Message | None = ( + db.session.query(Message).filter(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() + ) + if not last_run: + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=None, + current=current, + error_message="", + instruction=instruction, + node_type="llm", + ideal_output=ideal_output, + ) + last_run_dict = { + "query": last_run.query, + "answer": last_run.answer, + "error": last_run.error, + } + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=last_run_dict, + current=current, + error_message=str(last_run.error), + instruction=instruction, + node_type="llm", + ideal_output=ideal_output, + ) + + @staticmethod + def instruction_modify_workflow( + tenant_id: str, + flow_id: str, + node_id: str, + current: str, + instruction: str, + model_config: dict, + ideal_output: str | None, + ) -> dict: + from services.workflow_service import WorkflowService + + app: App | None = db.session.query(App).filter(App.id == flow_id).first() + if not app: + raise ValueError("App not found.") + workflow = WorkflowService().get_draft_workflow(app_model=app) + if not workflow: + raise ValueError("Workflow not found for the given app model.") + last_run = WorkflowService().get_node_last_run(app_model=app, workflow=workflow, node_id=node_id) + try: + node_type = cast(WorkflowNodeExecutionModel, last_run).node_type + except Exception: + try: + node_type = [it for it in workflow.graph_dict["graph"]["nodes"] if it["id"] == node_id][0]["data"][ + "type" + ] + except Exception: + node_type = "llm" + + if not last_run: # Node is not executed yet + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=None, + current=current, + error_message="", + instruction=instruction, + node_type=node_type, + ideal_output=ideal_output, + ) + + def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence: + raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG) + if not raw_agent_log: + return [] + parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log) + + def dict_of_event(event: AgentLogEvent) -> dict: + return { + "status": event.status, + "error": event.error, + "data": event.data, + } + + return [dict_of_event(event) for event in parsed] + + last_run_dict = { + "inputs": last_run.inputs_dict, + "status": last_run.status, + "error": last_run.error, + "agent_log": agent_log_of(last_run), + } + + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=last_run_dict, + current=current, + error_message=last_run.error, + instruction=instruction, + node_type=last_run.node_type, + ideal_output=ideal_output, + ) + + @staticmethod + def __instruction_modify_common( + tenant_id: str, + model_config: dict, + last_run: dict | None, + current: str | None, + error_message: str | None, + instruction: str, + node_type: str, + ideal_output: str | None, + ) -> dict: + LAST_RUN = "{{#last_run#}}" + CURRENT = "{{#current#}}" + ERROR_MESSAGE = "{{#error_message#}}" + injected_instruction = instruction + if LAST_RUN in injected_instruction: + injected_instruction = injected_instruction.replace(LAST_RUN, json.dumps(last_run)) + if CURRENT in injected_instruction: + injected_instruction = injected_instruction.replace(CURRENT, current or "null") + if ERROR_MESSAGE in injected_instruction: + injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null") + model_instance = ModelManager().get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), + ) + match node_type: + case "llm", "agent": + system_prompt = LLM_MODIFY_PROMPT_SYSTEM + case "code": + system_prompt = LLM_MODIFY_CODE_SYSTEM + case _: + system_prompt = LLM_MODIFY_PROMPT_SYSTEM + prompt_messages = [ + SystemPromptMessage(content=system_prompt), + UserPromptMessage( + content=json.dumps( + { + "current": current, + "last_run": last_run, + "instruction": injected_instruction, + "ideal_output": ideal_output, + } + ) + ), + ] + model_parameters = {"temperature": 0.4} + + try: + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ), + ) + + generated_raw = cast(str, response.message.content) + first_brace = generated_raw.find("{") + last_brace = generated_raw.rfind("}") + return {**json.loads(generated_raw[first_brace : last_brace + 1])} + + except InvokeError as e: + error = str(e) + return {"error": f"Failed to generate code. Error: {error}"} + except Exception as e: + logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e) + return {"error": f"An unexpected error occurred: {str(e)}"} diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ef81e38dc5..e38828578a 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -309,3 +309,116 @@ eg: Here is the JSON schema: {{schema}} """ # noqa: E501 + +LLM_MODIFY_PROMPT_SYSTEM = """ +Both your input and output should be in JSON format. + +! Below is the schema for input content ! +{ + "type": "object", + "description": "The user is trying to process some content with a prompt, but the output is not as expected. They hope to achieve their goal by modifying the prompt.", + "properties": { + "current": { + "type": "string", + "description": "The prompt before modification, where placeholders {{}} will be replaced with actual values for the large language model. The content in the placeholders should not be changed." + }, + "last_run": { + "type": "object", + "description": "The output result from the large language model after receiving the prompt.", + }, + "instruction": { + "type": "string", + "description": "User's instruction to edit the current prompt" + }, + "ideal_output": { + "type": "string", + "description": "The ideal output that the user expects from the large language model after modifying the prompt. You should compare the last output with the ideal output and make changes to the prompt to achieve the goal." + } + } +} +! Above is the schema for input content ! + +! Below is the schema for output content ! +{ + "type": "object", + "description": "Your feedback to the user after they provide modification suggestions.", + "properties": { + "modified": { + "type": "string", + "description": "Your modified prompt. You should change the original prompt as little as possible to achieve the goal. Keep the language of prompt if not asked to change" + }, + "message": { + "type": "string", + "description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user." + } + }, + "required": [ + "modified", + "message" + ] +} +! Above is the schema for output content ! + +Your output must strictly follow the schema format, do not output any content outside of the JSON body. +""" # noqa: E501 + +LLM_MODIFY_CODE_SYSTEM = """ +Both your input and output should be in JSON format. + +! Below is the schema for input content ! +{ + "type": "object", + "description": "The user is trying to process some data with a code snippet, but the result is not as expected. They hope to achieve their goal by modifying the code.", + "properties": { + "current": { + "type": "string", + "description": "The code before modification." + }, + "last_run": { + "type": "object", + "description": "The result of the code.", + }, + "message": { + "type": "string", + "description": "User's instruction to edit the current code" + } + } +} +! Above is the schema for input content ! + +! Below is the schema for output content ! +{ + "type": "object", + "description": "Your feedback to the user after they provide modification suggestions.", + "properties": { + "modified": { + "type": "string", + "description": "Your modified code. You should change the original code as little as possible to achieve the goal. Keep the programming language of code if not asked to change" + }, + "message": { + "type": "string", + "description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user." + } + }, + "required": [ + "modified", + "message" + ] +} +! Above is the schema for output content ! + +When you are modifying the code, you should remember: +- Do not use print, this not work in dify sandbox. +- Do not try dangerous call like deleting files. It's PROHIBITED. +- Do not use any library that is not built-in in with Python. +- Get inputs from the parameters of the function and have explicit type annotations. +- Write proper imports at the top of the code. +- Use return statement to return the result. +- You should return a `dict`. +Your output must strictly follow the schema format, do not output any content outside of the JSON body. +""" # noqa: E501 + +INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as expected: {{#last_run#}}. +You should edit the prompt according to the IDEAL OUTPUT.""" + +INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}.""" diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index d6dfe967d7..8efe105bbf 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -4,8 +4,8 @@ import math from typing import Any from pydantic import BaseModel, model_validator -from pyobvector import VECTOR, ObVecClient # type: ignore -from sqlalchemy import JSON, Column, String, func +from pyobvector import VECTOR, FtsIndexParam, FtsParser, ObVecClient, l2_distance # type: ignore +from sqlalchemy import JSON, Column, String from sqlalchemy.dialects.mysql import LONGTEXT from configs import dify_config @@ -119,14 +119,21 @@ class OceanBaseVector(BaseVector): ) try: if self._hybrid_search_enabled: - self._client.perform_raw_text_sql(f"""ALTER TABLE {self._collection_name} - ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER ik""") + self._client.create_fts_idx_with_fts_index_param( + table_name=self._collection_name, + fts_idx_param=FtsIndexParam( + index_name="fulltext_index_for_col_text", + field_names=["text"], + parser_type=FtsParser.IK, + ), + ) except Exception as e: raise Exception( "Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above " + "to support fulltext index and vector index in the same table", e, ) + self._client.refresh_metadata([self._collection_name]) redis_client.set(collection_exist_cache_key, 1, ex=3600) def _check_hybrid_search_support(self) -> bool: @@ -252,7 +259,7 @@ class OceanBaseVector(BaseVector): vec_column_name="vector", vec_data=query_vector, topk=topk, - distance_func=func.l2_distance, + distance_func=l2_distance, output_column_names=["text", "metadata"], with_dist=True, where_clause=_where_clause, diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 9741dd8b1d..fcf3a6d126 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -331,6 +331,12 @@ class QdrantVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from qdrant_client.http import models + score_threshold = float(kwargs.get("score_threshold") or 0.0) + if score_threshold >= 1: + # return empty list because some versions of qdrant may response with 400 bad request, + # and at the same time, the score_threshold with value 1 may be valid for other vector stores + return [] + filter = models.Filter( must=[ models.FieldCondition( @@ -355,7 +361,7 @@ class QdrantVector(BaseVector): limit=kwargs.get("top_k", 4), with_payload=True, with_vectors=True, - score_threshold=float(kwargs.get("score_threshold") or 0.0), + score_threshold=score_threshold, ) docs = [] for result in results: @@ -363,7 +369,6 @@ class QdrantVector(BaseVector): continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: metadata["score"] = result.score doc = Document( diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 5377cbbb69..5ffba07b44 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -108,10 +108,18 @@ class ApiProviderAuthType(Enum): :param value: mode value :return: mode """ + # 'api_key' deprecated in PR #21656 + # normalize & tiny alias for backward compatibility + v = (value or "").strip().lower() + if v == "api_key": + v = cls.API_KEY_HEADER.value + for mode in cls: - if mode.value == value: + if mode.value == v: return mode - raise ValueError(f"invalid mode value {value}") + + valid = ", ".join(m.value for m in cls) + raise ValueError(f"invalid mode value '{value}', expected one of: {valid}") class ToolInvokeMessage(BaseModel): diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index df89b2476d..4c8e13de70 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -318,33 +318,6 @@ class ToolNode(BaseNode): json.append(message.message.json_object) elif message.type == ToolInvokeMessage.MessageType.LINK: assert isinstance(message.message, ToolInvokeMessage.TextMessage) - - if message.meta: - transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - else: - transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"Tool file {tool_file_id} does not exist") - - mapping = { - "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), - "transfer_method": transfer_method, - "url": message.message.text, - } - - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - files.append(file) - stream_text = f"Link: {message.message.text}\n" text += stream_text yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"]) diff --git a/api/events/event_handlers/document_index_event.py b/api/events/document_index_event.py similarity index 100% rename from api/events/event_handlers/document_index_event.py rename to api/events/document_index_event.py diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index bdb69945f0..c607161e2a 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -5,7 +5,7 @@ import click from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from events.event_handlers.document_index_event import document_index_created +from events.document_index_event import document_index_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index bd72c93404..00e0bd9a16 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,4 +1,6 @@ +import ssl from datetime import timedelta +from typing import Any, Optional import pytz from celery import Celery, Task # type: ignore @@ -8,6 +10,40 @@ from configs import dify_config from dify_app import DifyApp +def _get_celery_ssl_options() -> Optional[dict[str, Any]]: + """Get SSL configuration for Celery broker/backend connections.""" + # Use REDIS_USE_SSL for consistency with the main Redis client + # Only apply SSL if we're using Redis as broker/backend + if not dify_config.REDIS_USE_SSL: + return None + + # Check if Celery is actually using Redis + broker_is_redis = dify_config.CELERY_BROKER_URL and ( + dify_config.CELERY_BROKER_URL.startswith("redis://") or dify_config.CELERY_BROKER_URL.startswith("rediss://") + ) + + if not broker_is_redis: + return None + + # Map certificate requirement strings to SSL constants + cert_reqs_map = { + "CERT_NONE": ssl.CERT_NONE, + "CERT_OPTIONAL": ssl.CERT_OPTIONAL, + "CERT_REQUIRED": ssl.CERT_REQUIRED, + } + + ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) + + ssl_options = { + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, + "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, + "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, + } + + return ssl_options + + def init_app(app: DifyApp) -> Celery: class FlaskTask(Task): def __call__(self, *args: object, **kwargs: object) -> object: @@ -33,14 +69,6 @@ def init_app(app: DifyApp) -> Celery: task_ignore_result=True, ) - # Add SSL options to the Celery configuration - ssl_options = { - "ssl_cert_reqs": None, - "ssl_ca_certs": None, - "ssl_certfile": None, - "ssl_keyfile": None, - } - celery_app.conf.update( result_backend=dify_config.CELERY_RESULT_BACKEND, broker_transport_options=broker_transport_options, @@ -51,9 +79,13 @@ def init_app(app: DifyApp) -> Celery: timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"), ) - if dify_config.BROKER_USE_SSL: + # Apply SSL configuration if enabled + ssl_options = _get_celery_ssl_options() + if ssl_options: celery_app.conf.update( - broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration + broker_use_ssl=ssl_options, + # Also apply SSL to the backend if it's Redis + redis_backend_use_ssl=ssl_options if dify_config.CELERY_BACKEND == "redis" else None, ) if dify_config.LOG_FILE: @@ -113,13 +145,19 @@ def init_app(app: DifyApp) -> Celery: minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30 ), } - if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: + if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED: imports.append("schedule.check_upgradable_plugin_task") beat_schedule["check_upgradable_plugin_task"] = { "task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task", "schedule": crontab(minute="*/15"), } - + if dify_config.WORKFLOW_LOG_CLEANUP_ENABLED: + # 2:00 AM every day + imports.append("schedule.clean_workflow_runlogs_precise") + beat_schedule["clean_workflow_runlogs_precise"] = { + "task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise", + "schedule": crontab(minute="0", hour="2"), + } celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 914d6219cf..f5f544679f 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -1,5 +1,6 @@ import functools import logging +import ssl from collections.abc import Callable from datetime import timedelta from typing import TYPE_CHECKING, Any, Union @@ -116,76 +117,132 @@ class RedisClientWrapper: redis_client: RedisClientWrapper = RedisClientWrapper() -def init_app(app: DifyApp): - global redis_client - connection_class: type[Union[Connection, SSLConnection]] = Connection - if dify_config.REDIS_USE_SSL: - connection_class = SSLConnection - resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL - if dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE: - if resp_protocol >= 3: - clientside_cache_config = CacheConfig() - else: - raise ValueError("Client side cache is only supported in RESP3") - else: - clientside_cache_config = None +def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: + """Get SSL configuration for Redis connection.""" + if not dify_config.REDIS_USE_SSL: + return Connection, {} - redis_params: dict[str, Any] = { + cert_reqs_map = { + "CERT_NONE": ssl.CERT_NONE, + "CERT_OPTIONAL": ssl.CERT_OPTIONAL, + "CERT_REQUIRED": ssl.CERT_REQUIRED, + } + ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) + + ssl_kwargs = { + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, + "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, + "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, + } + + return SSLConnection, ssl_kwargs + + +def _get_cache_configuration() -> CacheConfig | None: + """Get client-side cache configuration if enabled.""" + if not dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE: + return None + + resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL + if resp_protocol < 3: + raise ValueError("Client side cache is only supported in RESP3") + + return CacheConfig() + + +def _get_base_redis_params() -> dict[str, Any]: + """Get base Redis connection parameters.""" + return { "username": dify_config.REDIS_USERNAME, - "password": dify_config.REDIS_PASSWORD or None, # Temporary fix for empty password + "password": dify_config.REDIS_PASSWORD or None, "db": dify_config.REDIS_DB, "encoding": "utf-8", "encoding_errors": "strict", "decode_responses": False, - "protocol": resp_protocol, - "cache_config": clientside_cache_config, + "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, + "cache_config": _get_cache_configuration(), } - if dify_config.REDIS_USE_SENTINEL: - assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True" - assert dify_config.REDIS_SENTINEL_SERVICE_NAME is not None, ( - "REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True" - ) - sentinel_hosts = [ - (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",") - ] - sentinel = Sentinel( - sentinel_hosts, - sentinel_kwargs={ - "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, - "username": dify_config.REDIS_SENTINEL_USERNAME, - "password": dify_config.REDIS_SENTINEL_PASSWORD, - }, - ) - master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) - redis_client.initialize(master) - elif dify_config.REDIS_USE_CLUSTERS: - assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True" - nodes = [ - ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) - for node in dify_config.REDIS_CLUSTERS.split(",") - ] - redis_client.initialize( - RedisCluster( - startup_nodes=nodes, - password=dify_config.REDIS_CLUSTERS_PASSWORD, - protocol=resp_protocol, - cache_config=clientside_cache_config, - ) - ) - else: - redis_params.update( - { - "host": dify_config.REDIS_HOST, - "port": dify_config.REDIS_PORT, - "connection_class": connection_class, - "protocol": resp_protocol, - "cache_config": clientside_cache_config, - } - ) - pool = redis.ConnectionPool(**redis_params) - redis_client.initialize(redis.Redis(connection_pool=pool)) +def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: + """Create Redis client using Sentinel configuration.""" + if not dify_config.REDIS_SENTINELS: + raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True") + + if not dify_config.REDIS_SENTINEL_SERVICE_NAME: + raise ValueError("REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True") + + sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")] + + sentinel = Sentinel( + sentinel_hosts, + sentinel_kwargs={ + "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, + "username": dify_config.REDIS_SENTINEL_USERNAME, + "password": dify_config.REDIS_SENTINEL_PASSWORD, + }, + ) + + master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) + return master + + +def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: + """Create Redis cluster client.""" + if not dify_config.REDIS_CLUSTERS: + raise ValueError("REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True") + + nodes = [ + ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) + for node in dify_config.REDIS_CLUSTERS.split(",") + ] + + cluster: RedisCluster = RedisCluster( + startup_nodes=nodes, + password=dify_config.REDIS_CLUSTERS_PASSWORD, + protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL, + cache_config=_get_cache_configuration(), + ) + return cluster + + +def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: + """Create standalone Redis client.""" + connection_class, ssl_kwargs = _get_ssl_configuration() + + redis_params.update( + { + "host": dify_config.REDIS_HOST, + "port": dify_config.REDIS_PORT, + "connection_class": connection_class, + } + ) + + if ssl_kwargs: + redis_params.update(ssl_kwargs) + + pool = redis.ConnectionPool(**redis_params) + client: redis.Redis = redis.Redis(connection_pool=pool) + return client + + +def init_app(app: DifyApp): + """Initialize Redis client and attach it to the app.""" + global redis_client + + # Determine Redis mode and create appropriate client + if dify_config.REDIS_USE_SENTINEL: + redis_params = _get_base_redis_params() + client = _create_sentinel_client(redis_params) + elif dify_config.REDIS_USE_CLUSTERS: + client = _create_cluster_client() + else: + redis_params = _get_base_redis_params() + client = _create_standalone_client(redis_params) + + # Initialize the wrapper and attach to app + redis_client.initialize(client) app.extensions["redis"] = redis_client diff --git a/api/mypy.ini b/api/mypy.ini index 6836b2602b..3a6a54afe1 100644 --- a/api/mypy.ini +++ b/api/mypy.ini @@ -5,8 +5,7 @@ check_untyped_defs = True cache_fine_grained = True sqlite_cache = True exclude = (?x)( - core/model_runtime/model_providers/ - | tests/ + tests/ | migrations/ ) diff --git a/api/pyproject.toml b/api/pyproject.toml index 61a725a830..ce642aa9c8 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -205,7 +205,7 @@ vdb = [ "pgvector==0.2.5", "pymilvus~=2.5.0", "pymochow==1.3.1", - "pyobvector~=0.1.6", + "pyobvector~=0.2.15", "qdrant-client==1.9.0", "tablestore==6.2.0", "tcvectordb~=1.6.4", diff --git a/api/schedule/clean_workflow_runlogs_precise.py b/api/schedule/clean_workflow_runlogs_precise.py new file mode 100644 index 0000000000..0de3b5f68f --- /dev/null +++ b/api/schedule/clean_workflow_runlogs_precise.py @@ -0,0 +1,155 @@ +import datetime +import logging +import time + +import click + +import app +from configs import dify_config +from extensions.ext_database import db +from models.model import ( + AppAnnotationHitHistory, + Conversation, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.workflow import ConversationVariable, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun + +_logger = logging.getLogger(__name__) + + +MAX_RETRIES = 3 +BATCH_SIZE = dify_config.WORKFLOW_LOG_CLEANUP_BATCH_SIZE + + +@app.celery.task(queue="dataset") +def clean_workflow_runlogs_precise(): + """Clean expired workflow run logs with retry mechanism and complete message cascade""" + + click.echo(click.style("Start clean workflow run logs (precise mode with complete cascade).", fg="green")) + start_at = time.perf_counter() + + retention_days = dify_config.WORKFLOW_LOG_RETENTION_DAYS + cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days) + + try: + total_workflow_runs = db.session.query(WorkflowRun).filter(WorkflowRun.created_at < cutoff_date).count() + if total_workflow_runs == 0: + _logger.info("No expired workflow run logs found") + return + _logger.info("Found %s expired workflow run logs to clean", total_workflow_runs) + + total_deleted = 0 + failed_batches = 0 + batch_count = 0 + + while True: + workflow_runs = ( + db.session.query(WorkflowRun.id).filter(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all() + ) + + if not workflow_runs: + break + + workflow_run_ids = [run.id for run in workflow_runs] + batch_count += 1 + + success = _delete_batch_with_retry(workflow_run_ids, failed_batches) + + if success: + total_deleted += len(workflow_run_ids) + failed_batches = 0 + else: + failed_batches += 1 + if failed_batches >= MAX_RETRIES: + _logger.error("Failed to delete batch after %s retries, aborting cleanup for today", MAX_RETRIES) + break + else: + # Calculate incremental delay times: 5, 10, 15 minutes + retry_delay_minutes = failed_batches * 5 + _logger.warning("Batch deletion failed, retrying in %s minutes...", retry_delay_minutes) + time.sleep(retry_delay_minutes * 60) + continue + + _logger.info("Cleanup completed: %s expired workflow run logs deleted", total_deleted) + + except Exception as e: + db.session.rollback() + _logger.exception("Unexpected error in workflow log cleanup") + raise + + end_at = time.perf_counter() + execution_time = end_at - start_at + click.echo(click.style(f"Cleaned workflow run logs from db success latency: {execution_time:.2f}s", fg="green")) + + +def _delete_batch_with_retry(workflow_run_ids: list[str], attempt_count: int) -> bool: + """Delete a single batch with a retry mechanism and complete cascading deletion""" + try: + with db.session.begin_nested(): + message_data = ( + db.session.query(Message.id, Message.conversation_id) + .filter(Message.workflow_run_id.in_(workflow_run_ids)) + .all() + ) + message_id_list = [msg.id for msg in message_data] + conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id}) + if message_id_list: + db.session.query(AppAnnotationHitHistory).filter( + AppAnnotationHitHistory.message_id.in_(message_id_list) + ).delete(synchronize_session=False) + + db.session.query(MessageAgentThought).filter( + MessageAgentThought.message_id.in_(message_id_list) + ).delete(synchronize_session=False) + + db.session.query(MessageChain).filter(MessageChain.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(MessageFile).filter(MessageFile.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(MessageFeedback).filter(MessageFeedback.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(Message).filter(Message.workflow_run_id.in_(workflow_run_ids)).delete( + synchronize_session=False + ) + + db.session.query(WorkflowAppLog).filter(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete( + synchronize_session=False + ) + + db.session.query(WorkflowNodeExecutionModel).filter( + WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids) + ).delete(synchronize_session=False) + + if conversation_id_list: + db.session.query(ConversationVariable).filter( + ConversationVariable.conversation_id.in_(conversation_id_list) + ).delete(synchronize_session=False) + + db.session.query(Conversation).filter(Conversation.id.in_(conversation_id_list)).delete( + synchronize_session=False + ) + + db.session.query(WorkflowRun).filter(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False) + + db.session.commit() + return True + + except Exception as e: + db.session.rollback() + _logger.exception("Batch deletion failed (attempt %s)", attempt_count + 1) + return False diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 8934608da1..6ddda4c0c6 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import secrets import time import uuid from collections import Counter -from typing import Any, Optional +from typing import Any, Literal, Optional from flask_login import current_user from sqlalchemy import func, select @@ -51,7 +51,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( RetrievalModel, SegmentUpdateArgs, ) -from services.errors.account import InvalidActionError, NoPermissionError +from services.errors.account import NoPermissionError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError @@ -1800,14 +1800,16 @@ class DocumentService: raise ValueError("Process rule segmentation max_tokens is invalid") @staticmethod - def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user): + def batch_update_document_status( + dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user + ): """ Batch update document status. Args: dataset (Dataset): The dataset object document_ids (list[str]): List of document IDs to update - action (str): Action to perform (enable, disable, archive, un_archive) + action (Literal["enable", "disable", "archive", "un_archive"]): Action to perform user: Current user performing the action Raises: @@ -1890,9 +1892,10 @@ class DocumentService: raise propagation_error @staticmethod - def _prepare_document_status_update(document, action: str, user): - """ - Prepare document status update information. + def _prepare_document_status_update( + document: Document, action: Literal["enable", "disable", "archive", "un_archive"], user + ): + """Prepare document status update information. Args: document: Document object to update @@ -2355,7 +2358,9 @@ class SegmentService: db.session.commit() @classmethod - def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document): + def update_segments_status( + cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document + ): # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return @@ -2413,8 +2418,6 @@ class SegmentService: db.session.commit() disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) - else: - raise InvalidActionError() @classmethod def create_child_chunk( diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index b84dd0afc5..055fbb8138 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -47,7 +47,9 @@ class OAuthProxyService(BasePluginClient): if not context_id: raise ValueError("context_id is required") # get data from redis - data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}") + key = f"{OAuthProxyService.__KEY_PREFIX__}{context_id}" + data = redis_client.get(key) if not data: raise ValueError("context_id is invalid") + redis_client.delete(key) return json.loads(data) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 8c4c1876ad..5ab377c232 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -1,5 +1,6 @@ import logging import time +from typing import Literal import click from celery import shared_task # type: ignore @@ -13,7 +14,7 @@ from models.dataset import Document as DatasetDocument @shared_task(queue="dataset") -def deal_dataset_vector_index_task(dataset_id: str, action: str): +def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]): """ Async deal dataset from index :param dataset_id: dataset_id diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index 61d9a9e712..fe0e03f7b8 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -1,4 +1,5 @@ from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector +from core.rag.models.document import Document from tests.integration_tests.vdb.test_vector_store import ( AbstractVectorTest, setup_mock_redis, @@ -18,6 +19,14 @@ class QdrantVectorTest(AbstractVectorTest): ), ) + def search_by_vector(self): + super().search_by_vector() + # only test for qdrant, may not work on other vector stores + hits_by_vector: list[Document] = self.vector.search_by_vector( + query_vector=self.example_embedding, score_threshold=1 + ) + assert len(hits_by_vector) == 0 + def test_qdrant_vector(setup_mock_redis): QdrantVectorTest().run_all_tests() diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py new file mode 100644 index 0000000000..8bd5440411 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -0,0 +1,1785 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from services.feature_service import FeatureModel, FeatureService, KnowledgeRateLimitModel, SystemFeatureModel + + +class TestFeatureService: + """Integration tests for FeatureService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.feature_service.BillingService") as mock_billing_service, + patch("services.feature_service.EnterpriseService") as mock_enterprise_service, + ): + # Setup default mock returns for BillingService + mock_billing_service.get_info.return_value = { + "enabled": True, + "subscription": {"plan": "pro", "interval": "monthly", "education": True}, + "members": {"size": 5, "limit": 10}, + "apps": {"size": 3, "limit": 20}, + "vector_space": {"size": 2, "limit": 10}, + "documents_upload_quota": {"size": 15, "limit": 100}, + "annotation_quota_limit": {"size": 8, "limit": 50}, + "docs_processing": "enhanced", + "can_replace_logo": True, + "model_load_balancing_enabled": True, + "knowledge_rate_limit": {"limit": 100}, + } + + mock_billing_service.get_knowledge_rate_limit.return_value = {"limit": 100, "subscription_plan": "pro"} + + # Setup default mock returns for EnterpriseService + mock_enterprise_service.get_workspace_info.return_value = { + "WorkspaceMembers": {"used": 5, "limit": 10, "enabled": True} + } + + mock_enterprise_service.get_info.return_value = { + "SSOEnforcedForSignin": True, + "SSOEnforcedForSigninProtocol": "saml", + "EnableEmailCodeLogin": True, + "EnableEmailPasswordLogin": False, + "IsAllowRegister": False, + "IsAllowCreateWorkspace": False, + "Branding": { + "applicationTitle": "Test Enterprise", + "loginPageLogo": "https://example.com/logo.png", + "workspaceLogo": "https://example.com/workspace.png", + "favicon": "https://example.com/favicon.ico", + }, + "WebAppAuth": {"allowSso": True, "allowEmailCodeLogin": True, "allowEmailPasswordLogin": False}, + "SSOEnforcedForWebProtocol": "oidc", + "License": { + "status": "active", + "expiredAt": "2025-12-31", + "workspaces": {"enabled": True, "limit": 5, "used": 2}, + }, + "PluginInstallationPermission": { + "pluginInstallationScope": "official_only", + "restrictToMarketplaceOnly": True, + }, + } + + yield { + "billing_service": mock_billing_service, + "enterprise_service": mock_enterprise_service, + } + + def _create_test_tenant_id(self): + """Helper method to create a test tenant ID.""" + fake = Faker() + return fake.uuid4() + + def test_get_features_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful feature retrieval with billing and enterprise enabled. + + This test verifies: + - Proper feature model creation with all required fields + - Correct integration with billing service + - Proper enterprise workspace information handling + - Return value correctness and structure + """ + # Arrange: Setup test data with proper config mocking + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = True + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = True + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing features + assert result.billing.enabled is True + assert result.billing.subscription.plan == "pro" + assert result.billing.subscription.interval == "monthly" + assert result.education.activated is True + + # Verify member limitations + assert result.members.size == 5 + assert result.members.limit == 10 + + # Verify app limitations + assert result.apps.size == 3 + assert result.apps.limit == 20 + + # Verify vector space limitations + assert result.vector_space.size == 2 + assert result.vector_space.limit == 10 + + # Verify document upload quota + assert result.documents_upload_quota.size == 15 + assert result.documents_upload_quota.limit == 100 + + # Verify annotation quota + assert result.annotation_quota_limit.size == 8 + assert result.annotation_quota_limit.limit == 50 + + # Verify other features + assert result.docs_processing == "enhanced" + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is True + assert result.knowledge_rate_limit == 100 + + # Verify enterprise features + assert result.workspace_members.enabled is True + assert result.workspace_members.size == 5 + assert result.workspace_members.limit == 10 + + # Verify webapp copyright is enabled for non-sandbox plans + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with( + tenant_id + ) + + def test_get_features_sandbox_plan(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval for sandbox plan with specific limitations. + + This test verifies: + - Proper handling of sandbox plan limitations + - Correct webapp copyright settings for sandbox + - Transfer workspace restrictions for sandbox plans + - Proper billing service integration + """ + # Arrange: Setup sandbox plan mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = False + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = False + mock_config.EDUCATION_ENABLED = False + + # Set mock return value inside the patch context + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "sandbox", "interval": "monthly", "education": False}, + "members": {"size": 1, "limit": 3}, + "apps": {"size": 1, "limit": 5}, + "vector_space": {"size": 1, "limit": 2}, + "documents_upload_quota": {"size": 5, "limit": 20}, + "annotation_quota_limit": {"size": 2, "limit": 10}, + "docs_processing": "standard", + "can_replace_logo": False, + "model_load_balancing_enabled": False, + "knowledge_rate_limit": {"limit": 10}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify sandbox-specific limitations + assert result.billing.subscription.plan == "sandbox" + assert result.education.activated is False + + # Verify sandbox limitations + assert result.members.size == 1 + assert result.members.limit == 3 + assert result.apps.size == 1 + assert result.apps.limit == 5 + assert result.vector_space.size == 1 + assert result.vector_space.limit == 2 + assert result.documents_upload_quota.size == 5 + assert result.documents_upload_quota.limit == 20 + assert result.annotation_quota_limit.size == 2 + assert result.annotation_quota_limit.limit == 10 + + # Verify sandbox-specific restrictions + assert result.webapp_copyright_enabled is False + assert result.is_allow_transfer_workspace is False + assert result.can_replace_logo is False + assert result.model_load_balancing_enabled is False + assert result.docs_processing == "standard" + assert result.knowledge_rate_limit == 10 + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_knowledge_rate_limit_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful knowledge rate limit retrieval with billing enabled. + + This test verifies: + - Proper knowledge rate limit model creation + - Correct integration with billing service + - Proper rate limit configuration + - Return value correctness and structure + """ + # Arrange: Setup test data with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_knowledge_rate_limit(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, KnowledgeRateLimitModel) + + # Verify rate limit configuration + assert result.enabled is True + assert result.limit == 100 + assert result.subscription_plan == "pro" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_called_once_with( + tenant_id + ) + + def test_get_system_features_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful system features retrieval with enterprise and marketplace enabled. + + This test verifies: + - Proper system feature model creation + - Correct integration with enterprise service + - Proper marketplace configuration + - Return value correctness and structure + """ + # Arrange: Setup test data with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = True + mock_config.ENABLE_EMAIL_CODE_LOGIN = True + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify SSO configuration + assert result.sso_enforced_for_signin is True + assert result.sso_enforced_for_signin_protocol == "saml" + + # Verify authentication settings + assert result.enable_email_code_login is True + assert result.enable_email_password_login is False + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify branding configuration + assert result.branding.application_title == "Test Enterprise" + assert result.branding.login_page_logo == "https://example.com/logo.png" + assert result.branding.workspace_logo == "https://example.com/workspace.png" + assert result.branding.favicon == "https://example.com/favicon.ico" + + # Verify webapp auth configuration + assert result.webapp_auth.allow_sso is True + assert result.webapp_auth.allow_email_code_login is True + assert result.webapp_auth.allow_email_password_login is False + assert result.webapp_auth.sso_config.protocol == "oidc" + + # Verify license configuration + assert result.license.status.value == "active" + assert result.license.expired_at == "2025-12-31" + assert result.license.workspaces.enabled is True + assert result.license.workspaces.limit == 5 + assert result.license.workspaces.size == 2 + + # Verify plugin installation permission + assert result.plugin_installation_permission.plugin_installation_scope == "official_only" + assert result.plugin_installation_permission.restrict_to_marketplace_only is True + + # Verify marketplace configuration + assert result.enable_marketplace is True + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval with basic configuration (no enterprise). + + This test verifies: + - Proper system feature model creation without enterprise + - Correct environment variable handling + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup basic config mock (no enterprise) + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = True + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = True + mock_config.ALLOW_CREATE_WORKSPACE = True + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify basic configuration + assert result.branding.enabled is False + assert result.webapp_auth.enabled is False + assert result.enable_change_email is True + + # Verify authentication settings from config + assert result.enable_email_code_login is True + assert result.enable_email_password_login is True + assert result.enable_social_oauth_login is False + assert result.is_allow_register is True + assert result.is_allow_create_workspace is True + assert result.is_email_setup is True + + # Verify marketplace configuration + assert result.enable_marketplace is False + + # Verify plugin package size (uses default value from dify_config) + assert result.max_plugin_package_size == 15728640 + + def test_get_features_billing_disabled(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval when billing is disabled. + + This test verifies: + - Proper feature model creation without billing + - Correct environment variable handling + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup billing disabled mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = True + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = True + + tenant_id = self._create_test_tenant_id() + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing is disabled + assert result.billing.enabled is False + + # Verify environment-based features + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is True + assert result.dataset_operator_enabled is True + assert result.education.enabled is True + + # Verify default limitations + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify no enterprise features + assert result.workspace_members.enabled is False + assert result.webapp_copyright_enabled is False + + def test_get_knowledge_rate_limit_billing_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test knowledge rate limit retrieval when billing is disabled. + + This test verifies: + - Proper knowledge rate limit model creation without billing + - Default rate limit configuration + - Return value correctness and structure + """ + # Arrange: Setup billing disabled mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + + tenant_id = self._create_test_tenant_id() + + # Act: Execute the method under test + result = FeatureService.get_knowledge_rate_limit(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, KnowledgeRateLimitModel) + + # Verify default configuration + assert result.enabled is False + assert result.limit == 10 + assert result.subscription_plan == "" # Empty string when billing is disabled + + # Verify no billing service calls + mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_not_called() + + def test_get_features_enterprise_only(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with enterprise enabled but billing disabled. + + This test verifies: + - Proper feature model creation with enterprise only + - Correct enterprise service integration + - Proper workspace member handling + - Return value correctness and structure + """ + # Arrange: Setup enterprise only mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = True + mock_config.CAN_REPLACE_LOGO = False + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = False + mock_config.EDUCATION_ENABLED = False + + tenant_id = self._create_test_tenant_id() + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing is disabled + assert result.billing.enabled is False + + # Verify enterprise features + assert result.webapp_copyright_enabled is True + + # Verify workspace members from enterprise + assert result.workspace_members.enabled is True + assert result.workspace_members.size == 5 + assert result.workspace_members.limit == 10 + + # Verify environment-based features + assert result.can_replace_logo is False + assert result.model_load_balancing_enabled is False + assert result.dataset_operator_enabled is False + assert result.education.enabled is False + + # Verify default limitations + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with( + tenant_id + ) + mock_external_service_dependencies["billing_service"].get_info.assert_not_called() + + def test_get_system_features_enterprise_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval when enterprise is disabled. + + This test verifies: + - Proper system feature model creation without enterprise + - Correct environment variable handling + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup enterprise disabled mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + mock_config.MARKETPLACE_ENABLED = True + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = True + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = None + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 50 + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features are disabled + assert result.branding.enabled is False + assert result.webapp_auth.enabled is False + assert result.enable_change_email is True + + # Verify authentication settings from config + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.enable_social_oauth_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + assert result.is_email_setup is False + + # Verify marketplace configuration + assert result.enable_marketplace is True + + # Verify plugin package size (uses default value from dify_config) + assert result.max_plugin_package_size == 15728640 + + # Verify default license status + assert result.license.status.value == "none" + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + + # Verify no enterprise service calls + mock_external_service_dependencies["enterprise_service"].get_info.assert_not_called() + + def test_get_features_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval without tenant ID (billing disabled). + + This test verifies: + - Proper feature model creation without tenant ID + - Correct handling when billing is disabled + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup no tenant ID scenario + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + # Act: Execute the method under test + result = FeatureService.get_features("") + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing is disabled due to no tenant ID + assert result.billing.enabled is False + + # Verify environment-based features + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is False + assert result.dataset_operator_enabled is True + assert result.education.enabled is False + + # Verify default limitations + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + + # Verify no billing service calls + mock_external_service_dependencies["billing_service"].get_info.assert_not_called() + + def test_get_features_partial_billing_info(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with partial billing information. + + This test verifies: + - Proper handling of partial billing data + - Correct fallback to default values + - Proper billing service integration + - Return value correctness and structure + """ + # Arrange: Setup partial billing info mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "basic", "interval": "yearly"}, + # Missing members, apps, vector_space, etc. + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing features + assert result.billing.enabled is True + assert result.billing.subscription.plan == "basic" + assert result.billing.subscription.interval == "yearly" + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify basic plan restrictions (non-sandbox plans have webapp copyright enabled) + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_features_edge_case_vector_space(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case vector space configuration. + + This test verifies: + - Proper handling of vector space quota limits + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case vector space mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "pro", "interval": "monthly"}, + "vector_space": {"size": 0, "limit": 0}, + "apps": {"size": 5, "limit": 10}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify vector space configuration + assert result.vector_space.size == 0 + assert result.vector_space.limit == 0 + + # Verify apps configuration + assert result.apps.size == 5 + assert result.apps.limit == 10 + + # Verify pro plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_webapp_auth( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with edge case webapp auth configuration. + + This test verifies: + - Proper handling of webapp auth configuration + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case webapp auth mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "WebAppAuth": {"allowSso": False, "allowEmailCodeLogin": True, "allowEmailPasswordLogin": False} + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify webapp auth configuration + assert result.webapp_auth.allow_sso is False + assert result.webapp_auth.allow_email_code_login is True + assert result.webapp_auth.allow_email_password_login is False + assert result.webapp_auth.sso_config.protocol == "" + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify default values for missing enterprise info + assert result.sso_enforced_for_signin is False + assert result.sso_enforced_for_signin_protocol == "" + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_members_quota(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case members quota configuration. + + This test verifies: + - Proper handling of members quota limits + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case members quota mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "basic", "interval": "yearly"}, + "members": {"size": 10, "limit": 10}, + "vector_space": {"size": 3, "limit": 5}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify members configuration + assert result.members.size == 10 + assert result.members.limit == 10 + + # Verify vector space configuration + assert result.vector_space.size == 3 + assert result.vector_space.limit == 5 + + # Verify basic plan features (non-sandbox plans have webapp copyright enabled) + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_plugin_installation_permission_scopes( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with different plugin installation permission scopes. + + This test verifies: + - Proper handling of different plugin installation scopes + - Correct enterprise service integration + - Proper permission configuration + - Return value correctness and structure + """ + + # Test case 1: Official only scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": { + "pluginInstallationScope": "official_only", + "restrictToMarketplaceOnly": True, + } + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "official_only" + assert result.plugin_installation_permission.restrict_to_marketplace_only is True + + # Test case 2: All plugins scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": {"pluginInstallationScope": "all", "restrictToMarketplaceOnly": False} + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "all" + assert result.plugin_installation_permission.restrict_to_marketplace_only is False + + # Test case 3: Specific partners scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": { + "pluginInstallationScope": "official_and_specific_partners", + "restrictToMarketplaceOnly": False, + } + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "official_and_specific_partners" + assert result.plugin_installation_permission.restrict_to_marketplace_only is False + + # Test case 4: None scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": {"pluginInstallationScope": "none", "restrictToMarketplaceOnly": True} + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "none" + assert result.plugin_installation_permission.restrict_to_marketplace_only is True + + def test_get_features_workspace_members_missing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval when workspace members info is missing from enterprise. + + This test verifies: + - Proper handling of missing workspace members data + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup missing workspace members mock + tenant_id = self._create_test_tenant_id() + mock_external_service_dependencies["enterprise_service"].get_workspace_info.return_value = { + # Missing WorkspaceMembers key + } + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify workspace members use default values + assert result.workspace_members.enabled is False + assert result.workspace_members.size == 0 + assert result.workspace_members.limit == 0 + + # Verify enterprise features + assert result.webapp_copyright_enabled is True + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with( + tenant_id + ) + + def test_get_system_features_license_inactive(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval with inactive license. + + This test verifies: + - Proper handling of inactive license status + - Correct enterprise service integration + - Proper license status handling + - Return value correctness and structure + """ + # Arrange: Setup inactive license mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "License": { + "status": "inactive", + "expiredAt": "", + "workspaces": {"enabled": False, "limit": 0, "used": 0}, + } + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify license status + assert result.license.status == "inactive" + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + assert result.license.workspaces.size == 0 + assert result.license.workspaces.limit == 0 + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_system_features_partial_enterprise_info( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with partial enterprise information. + + This test verifies: + - Proper handling of partial enterprise data + - Correct fallback to default values + - Proper enterprise service integration + - Return value correctness and structure + """ + # Arrange: Setup partial enterprise info mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "SSOEnforcedForSignin": True, + "Branding": {"applicationTitle": "Partial Enterprise"}, + # Missing WebAppAuth, License, PluginInstallationPermission, etc. + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify SSO configuration + assert result.sso_enforced_for_signin is True + assert result.sso_enforced_for_signin_protocol == "" + + # Verify branding configuration (partial) + assert result.branding.application_title == "Partial Enterprise" + assert result.branding.login_page_logo == "" + assert result.branding.workspace_logo == "" + assert result.branding.favicon == "" + + # Verify default values for missing enterprise info + assert result.webapp_auth.allow_sso is False + assert result.webapp_auth.allow_email_code_login is False + assert result.webapp_auth.allow_email_password_login is False + assert result.webapp_auth.sso_config.protocol == "" + + # Verify default license status + assert result.license.status == "none" + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + + # Verify default plugin installation permission + assert result.plugin_installation_permission.plugin_installation_scope == "all" + assert result.plugin_installation_permission.restrict_to_marketplace_only is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_limits(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case limit values. + + This test verifies: + - Proper handling of zero and negative limits + - Correct handling of very large limits + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case limits mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "enterprise", "interval": "yearly"}, + "members": {"size": 0, "limit": 0}, + "apps": {"size": 0, "limit": -1}, + "vector_space": {"size": 0, "limit": 999999}, + "documents_upload_quota": {"size": 0, "limit": 0}, + "annotation_quota_limit": {"size": 0, "limit": 1}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify edge case limits + assert result.members.size == 0 + assert result.members.limit == 0 + assert result.apps.size == 0 + assert result.apps.limit == -1 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 999999 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 0 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 1 + + # Verify enterprise plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_protocols( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with edge case protocol values. + + This test verifies: + - Proper handling of empty protocol strings + - Correct handling of special protocol values + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case protocols mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "SSOEnforcedForSigninProtocol": "", + "SSOEnforcedForWebProtocol": " ", + "WebAppAuth": {"allowSso": True, "allowEmailCodeLogin": False, "allowEmailPasswordLogin": True}, + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify edge case protocols + assert result.sso_enforced_for_signin_protocol == "" + assert result.webapp_auth.sso_config.protocol == " " + + # Verify webapp auth configuration + assert result.webapp_auth.allow_sso is True + assert result.webapp_auth.allow_email_code_login is False + assert result.webapp_auth.allow_email_password_login is True + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_education(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case education configuration. + + This test verifies: + - Proper handling of education feature flags + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case education mock + tenant_id = self._create_test_tenant_id() + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "education", "interval": "semester", "education": True}, + "members": {"size": 100, "limit": 200}, + "apps": {"size": 50, "limit": 100}, + "vector_space": {"size": 20, "limit": 50}, + "documents_upload_quota": {"size": 500, "limit": 1000}, + "annotation_quota_limit": {"size": 200, "limit": 500}, + } + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.EDUCATION_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify education features + assert result.education.enabled is True + assert result.education.activated is True + + # Verify education plan limits + assert result.members.size == 100 + assert result.members.limit == 200 + assert result.apps.size == 50 + assert result.apps.limit == 100 + assert result.vector_space.size == 20 + assert result.vector_space.limit == 50 + assert result.documents_upload_quota.size == 500 + assert result.documents_upload_quota.limit == 1000 + assert result.annotation_quota_limit.size == 200 + assert result.annotation_quota_limit.limit == 500 + + # Verify education plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_license_limitation_model_is_available( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test LicenseLimitationModel.is_available method with various scenarios. + + This test verifies: + - Proper quota availability calculation + - Correct handling of unlimited limits + - Proper handling of disabled limits + - Return value correctness for different scenarios + """ + from services.feature_service import LicenseLimitationModel + + # Test case 1: Limit disabled + disabled_limit = LicenseLimitationModel(enabled=False, size=5, limit=10) + assert disabled_limit.is_available(3) is True + assert disabled_limit.is_available(10) is True + + # Test case 2: Unlimited limit + unlimited_limit = LicenseLimitationModel(enabled=True, size=5, limit=0) + assert unlimited_limit.is_available(3) is True + assert unlimited_limit.is_available(100) is True + + # Test case 3: Available quota + available_limit = LicenseLimitationModel(enabled=True, size=5, limit=10) + assert available_limit.is_available(3) is True + assert available_limit.is_available(5) is True + assert available_limit.is_available(1) is True + + # Test case 4: Insufficient quota + insufficient_limit = LicenseLimitationModel(enabled=True, size=8, limit=10) + assert insufficient_limit.is_available(3) is False + assert insufficient_limit.is_available(2) is True + assert insufficient_limit.is_available(1) is True + + # Test case 5: Exact quota usage + exact_limit = LicenseLimitationModel(enabled=True, size=7, limit=10) + assert exact_limit.is_available(3) is True + assert exact_limit.is_available(3) is True + + def test_get_features_workspace_members_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval when workspace members are disabled in enterprise. + + This test verifies: + - Proper handling of disabled workspace members + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup workspace members disabled mock + tenant_id = self._create_test_tenant_id() + mock_external_service_dependencies["enterprise_service"].get_workspace_info.return_value = { + "WorkspaceMembers": {"used": 0, "limit": 0, "enabled": False} + } + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify workspace members are disabled + assert result.workspace_members.enabled is False + assert result.workspace_members.size == 0 + assert result.workspace_members.limit == 0 + + # Verify enterprise features + assert result.webapp_copyright_enabled is True + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with(tenant_id) + + def test_get_system_features_license_expired(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval with expired license. + + This test verifies: + - Proper handling of expired license status + - Correct enterprise service integration + - Proper license status handling + - Return value correctness and structure + """ + # Arrange: Setup expired license mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "License": { + "status": "expired", + "expiredAt": "2023-12-31", + "workspaces": {"enabled": False, "limit": 0, "used": 0}, + } + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify license status + assert result.license.status == "expired" + assert result.license.expired_at == "2023-12-31" + assert result.license.workspaces.enabled is False + assert result.license.workspaces.size == 0 + assert result.license.workspaces.limit == 0 + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_docs_processing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with edge case document processing configuration. + + This test verifies: + - Proper handling of different document processing modes + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case docs processing mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = True + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "premium", "interval": "monthly"}, + "docs_processing": "advanced", + "can_replace_logo": True, + "model_load_balancing_enabled": True, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify docs processing configuration + assert result.docs_processing == "advanced" + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is True + + # Verify premium plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default limitations (no specific billing info) + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_branding( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with edge case branding configuration. + + This test verifies: + - Proper handling of partial branding information + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case branding mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "Branding": { + "applicationTitle": "Edge Case App", + "loginPageLogo": None, + "workspaceLogo": "", + "favicon": "https://example.com/favicon.ico", + } + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify branding configuration (edge cases) + assert result.branding.application_title == "Edge Case App" + assert result.branding.login_page_logo is None # None value from mock + assert result.branding.workspace_logo == "" + assert result.branding.favicon == "https://example.com/favicon.ico" + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify default values for missing enterprise info + assert result.sso_enforced_for_signin is False + assert result.sso_enforced_for_signin_protocol == "" + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_annotation_quota( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with edge case annotation quota configuration. + + This test verifies: + - Proper handling of annotation quota limits + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case annotation quota mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "enterprise", "interval": "yearly"}, + "annotation_quota_limit": {"size": 999, "limit": 1000}, + "knowledge_rate_limit": {"limit": 500}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify annotation quota configuration + assert result.annotation_quota_limit.size == 999 + assert result.annotation_quota_limit.limit == 1000 + + # Verify knowledge rate limit + assert result.knowledge_rate_limit == 500 + + # Verify enterprise plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_features_edge_case_documents_upload( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with edge case documents upload settings. + + This test verifies: + - Proper handling of edge case documents upload configuration + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case documents upload mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "pro", "interval": "monthly"}, + "documents_upload_quota": { + "size": 0, # Edge case: zero current size + "limit": 0, # Edge case: zero limit + }, + "knowledge_rate_limit": {"limit": 100}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify documents upload quota configuration (edge cases) + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 0 + + # Verify knowledge rate limit + assert result.knowledge_rate_limit == 100 + + # Verify pro plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 # Default value when not provided + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_license_lost( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features with lost license status. + + This test verifies: + - Proper handling of lost license status + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup lost license mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "license": {"status": "lost", "expired_at": None, "plan": None} + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify default values for missing enterprise info + assert result.sso_enforced_for_signin is False + assert result.sso_enforced_for_signin_protocol == "" + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_education_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with education feature disabled. + + This test verifies: + - Proper handling of disabled education features + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup education disabled mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": { + "plan": "pro", + "interval": "monthly", + "education": False, # Education explicitly disabled + }, + "knowledge_rate_limit": {"limit": 100}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify education configuration + assert result.education.activated is False + + # Verify knowledge rate limit + assert result.knowledge_rate_limit == 100 + + # Verify pro plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 # Default value when not provided + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py new file mode 100644 index 0000000000..7fef572c14 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -0,0 +1,1144 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from core.rag.index_processor.constant.built_in_field import BuiltInField +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document +from services.entities.knowledge_entities.knowledge_entities import MetadataArgs +from services.metadata_service import MetadataService + + +class TestMetadataService: + """Integration tests for MetadataService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.metadata_service.current_user") as mock_current_user, + patch("services.metadata_service.redis_client") as mock_redis_client, + patch("services.dataset_service.DocumentService") as mock_document_service, + ): + # Setup default mock returns + mock_redis_client.get.return_value = None + mock_redis_client.set.return_value = True + mock_redis_client.delete.return_value = 1 + + yield { + "current_user": mock_current_user, + "redis_client": mock_redis_client, + "document_service": mock_document_service, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, account, tenant): + """ + Helper method to create a test dataset for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + account: Account instance + tenant: Tenant instance + + Returns: + Dataset: Created dataset instance + """ + fake = Faker() + + dataset = Dataset( + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + created_by=account.id, + built_in_field_enabled=False, + ) + + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + return dataset + + def _create_test_document(self, db_session_with_containers, mock_external_service_dependencies, dataset, account): + """ + Helper method to create a test document for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + dataset: Dataset instance + account: Account instance + + Returns: + Document: Created document instance + """ + fake = Faker() + + document = Document( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=1, + data_source_type="upload_file", + data_source_info="{}", + batch="test-batch", + name=fake.file_name(), + created_from="web", + created_by=account.id, + doc_form="text", + doc_language="en", + ) + + from extensions.ext_database import db + + db.session.add(document) + db.session.commit() + + return document + + def test_create_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful metadata creation with valid parameters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + metadata_args = MetadataArgs(type="string", name="test_metadata") + + # Act: Execute the method under test + result = MetadataService.create_metadata(dataset.id, metadata_args) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.name == "test_metadata" + assert result.type == "string" + assert result.dataset_id == dataset.id + assert result.tenant_id == tenant.id + assert result.created_by == account.id + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.id is not None + assert result.created_at is not None + + def test_create_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata creation fails when name exceeds 255 characters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + long_name = "a" * 256 # 256 characters, exceeding 255 limit + metadata_args = MetadataArgs(type="string", name=long_name) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): + MetadataService.create_metadata(dataset.id, metadata_args) + + def test_create_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata creation fails when name already exists in the same dataset. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create first metadata + first_metadata_args = MetadataArgs(type="string", name="duplicate_name") + MetadataService.create_metadata(dataset.id, first_metadata_args) + + # Try to create second metadata with same name + second_metadata_args = MetadataArgs(type="number", name="duplicate_name") + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name already exists."): + MetadataService.create_metadata(dataset.id, second_metadata_args) + + def test_create_metadata_name_conflicts_with_built_in_field( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata creation fails when name conflicts with built-in field names. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Try to create metadata with built-in field name + built_in_field_name = BuiltInField.document_name.value + metadata_args = MetadataArgs(type="string", name=built_in_field_name) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): + MetadataService.create_metadata(dataset.id, metadata_args) + + def test_update_metadata_name_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful metadata name update with valid parameters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Act: Execute the method under test + new_name = "new_name" + result = MetadataService.update_metadata_name(dataset.id, metadata.id, new_name) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.name == new_name + assert result.updated_by == account.id + assert result.updated_at is not None + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.name == new_name + + def test_update_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata name update fails when new name exceeds 255 characters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Try to update with too long name + long_name = "a" * 256 # 256 characters, exceeding 255 limit + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): + MetadataService.update_metadata_name(dataset.id, metadata.id, long_name) + + def test_update_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata name update fails when new name already exists in the same dataset. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create two metadata entries + first_metadata_args = MetadataArgs(type="string", name="first_metadata") + first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args) + + second_metadata_args = MetadataArgs(type="number", name="second_metadata") + second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args) + + # Try to update first metadata with second metadata's name + with pytest.raises(ValueError, match="Metadata name already exists."): + MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata") + + def test_update_metadata_name_conflicts_with_built_in_field( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata name update fails when new name conflicts with built-in field names. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Try to update with built-in field name + built_in_field_name = BuiltInField.document_name.value + + with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): + MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) + + def test_update_metadata_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata name update fails when metadata ID does not exist. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Try to update non-existent metadata + import uuid + + fake_metadata_id = str(uuid.uuid4()) # Use valid UUID format + new_name = "new_name" + + # Act: Execute the method under test + result = MetadataService.update_metadata_name(dataset.id, fake_metadata_id, new_name) + + # Assert: Verify the method returns None when metadata is not found + assert result is None + + def test_delete_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful metadata deletion with valid parameters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="to_be_deleted") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Act: Execute the method under test + result = MetadataService.delete_metadata(dataset.id, metadata.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.id == metadata.id + + # Verify metadata was deleted from database + from extensions.ext_database import db + + deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + assert deleted_metadata is None + + def test_delete_metadata_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata deletion fails when metadata ID does not exist. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Try to delete non-existent metadata + import uuid + + fake_metadata_id = str(uuid.uuid4()) # Use valid UUID format + + # Act: Execute the method under test + result = MetadataService.delete_metadata(dataset.id, fake_metadata_id) + + # Assert: Verify the method returns None when metadata is not found + assert result is None + + def test_delete_metadata_with_document_bindings( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata deletion successfully removes document metadata bindings. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Create metadata binding + binding = DatasetMetadataBinding( + tenant_id=tenant.id, + dataset_id=dataset.id, + metadata_id=metadata.id, + document_id=document.id, + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(binding) + db.session.commit() + + # Set document metadata + document.doc_metadata = {"test_metadata": "test_value"} + db.session.add(document) + db.session.commit() + + # Act: Execute the method under test + result = MetadataService.delete_metadata(dataset.id, metadata.id) + + # Assert: Verify the expected outcomes + assert result is not None + + # Verify metadata was deleted from database + deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + assert deleted_metadata is None + + # Note: The service attempts to update document metadata but may not succeed + # due to mock configuration. The main functionality (metadata deletion) is verified. + + def test_get_built_in_fields_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of built-in metadata fields. + """ + # Act: Execute the method under test + result = MetadataService.get_built_in_fields() + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 5 + + # Verify all expected built-in fields are present + field_names = [field["name"] for field in result] + field_types = [field["type"] for field in result] + + assert BuiltInField.document_name.value in field_names + assert BuiltInField.uploader.value in field_names + assert BuiltInField.upload_date.value in field_names + assert BuiltInField.last_update_date.value in field_names + assert BuiltInField.source.value in field_names + + # Verify field types + assert "string" in field_types + assert "time" in field_types + + def test_enable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful enabling of built-in fields for a dataset. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ + document + ] + + # Verify dataset starts with built-in fields disabled + assert dataset.built_in_field_enabled is False + + # Act: Execute the method under test + MetadataService.enable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is True + + # Note: Document metadata update depends on DocumentService mock working correctly + # The main functionality (enabling built-in fields) is verified + + def test_enable_built_in_field_already_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test enabling built-in fields when they are already enabled. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Enable built-in fields first + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.enable_built_in_field(dataset) + + # Assert: Verify the method returns early without changes + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is True + + def test_enable_built_in_field_with_no_documents( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test enabling built-in fields for a dataset with no documents. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Mock DocumentService.get_working_documents_by_dataset_id to return empty list + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.enable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is True + + def test_disable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful disabling of built-in fields for a dataset. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Enable built-in fields first + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Set document metadata with built-in fields + document.doc_metadata = { + BuiltInField.document_name.value: document.name, + BuiltInField.uploader.value: "test_uploader", + BuiltInField.upload_date.value: 1234567890.0, + BuiltInField.last_update_date.value: 1234567890.0, + BuiltInField.source.value: "test_source", + } + db.session.add(document) + db.session.commit() + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ + document + ] + + # Act: Execute the method under test + MetadataService.disable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is False + + # Note: Document metadata update depends on DocumentService mock working correctly + # The main functionality (disabling built-in fields) is verified + + def test_disable_built_in_field_already_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test disabling built-in fields when they are already disabled. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Verify dataset starts with built-in fields disabled + assert dataset.built_in_field_enabled is False + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.disable_built_in_field(dataset) + + # Assert: Verify the method returns early without changes + from extensions.ext_database import db + + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is False + + def test_disable_built_in_field_with_no_documents( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test disabling built-in fields for a dataset with no documents. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Enable built-in fields first + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Mock DocumentService.get_working_documents_by_dataset_id to return empty list + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.disable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is False + + def test_update_documents_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful update of documents metadata. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Mock DocumentService.get_document + mock_external_service_dependencies["document_service"].get_document.return_value = document + + # Create metadata operation data + from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, + ) + + metadata_detail = MetadataDetail(id=metadata.id, name=metadata.name, value="test_value") + + operation = DocumentMetadataOperation(document_id=document.id, metadata_list=[metadata_detail]) + + operation_data = MetadataOperationData(operation_data=[operation]) + + # Act: Execute the method under test + MetadataService.update_documents_metadata(dataset, operation_data) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + # Verify document metadata was updated + db.session.refresh(document) + assert document.doc_metadata is not None + assert "test_metadata" in document.doc_metadata + assert document.doc_metadata["test_metadata"] == "test_value" + + # Verify metadata binding was created + binding = ( + db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata.id, document_id=document.id).first() + ) + assert binding is not None + assert binding.tenant_id == tenant.id + assert binding.dataset_id == dataset.id + + def test_update_documents_metadata_with_built_in_fields_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test update of documents metadata when built-in fields are enabled. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Enable built-in fields + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Mock DocumentService.get_document + mock_external_service_dependencies["document_service"].get_document.return_value = document + + # Create metadata operation data + from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, + ) + + metadata_detail = MetadataDetail(id=metadata.id, name=metadata.name, value="test_value") + + operation = DocumentMetadataOperation(document_id=document.id, metadata_list=[metadata_detail]) + + operation_data = MetadataOperationData(operation_data=[operation]) + + # Act: Execute the method under test + MetadataService.update_documents_metadata(dataset, operation_data) + + # Assert: Verify the expected outcomes + # Verify document metadata was updated with both custom and built-in fields + db.session.refresh(document) + assert document.doc_metadata is not None + assert "test_metadata" in document.doc_metadata + assert document.doc_metadata["test_metadata"] == "test_value" + + # Note: Built-in fields would be added if DocumentService mock works correctly + # The main functionality (custom metadata update) is verified + + def test_update_documents_metadata_document_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test update of documents metadata when document is not found. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Mock DocumentService.get_document to return None (document not found) + mock_external_service_dependencies["document_service"].get_document.return_value = None + + # Create metadata operation data + from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, + ) + + metadata_detail = MetadataDetail(id=metadata.id, name=metadata.name, value="test_value") + + operation = DocumentMetadataOperation(document_id="non-existent-document-id", metadata_list=[metadata_detail]) + + operation_data = MetadataOperationData(operation_data=[operation]) + + # Act: Execute the method under test + # The method should handle the error gracefully and continue + MetadataService.update_documents_metadata(dataset, operation_data) + + # Assert: Verify the method completes without raising exceptions + # The main functionality (error handling) is verified + + def test_knowledge_base_metadata_lock_check_dataset_id( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check for dataset operations. + """ + # Arrange: Setup mocks + mock_external_service_dependencies["redis_client"].get.return_value = None + mock_external_service_dependencies["redis_client"].set.return_value = True + + dataset_id = "test-dataset-id" + + # Act: Execute the method under test + MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) + + # Assert: Verify the expected outcomes + # Verify Redis lock was set + mock_external_service_dependencies["redis_client"].set.assert_called_once() + + # Verify lock key format + call_args = mock_external_service_dependencies["redis_client"].set.call_args + assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}" + + def test_knowledge_base_metadata_lock_check_document_id( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check for document operations. + """ + # Arrange: Setup mocks + mock_external_service_dependencies["redis_client"].get.return_value = None + mock_external_service_dependencies["redis_client"].set.return_value = True + + document_id = "test-document-id" + + # Act: Execute the method under test + MetadataService.knowledge_base_metadata_lock_check(None, document_id) + + # Assert: Verify the expected outcomes + # Verify Redis lock was set + mock_external_service_dependencies["redis_client"].set.assert_called_once() + + # Verify lock key format + call_args = mock_external_service_dependencies["redis_client"].set.call_args + assert call_args[0][0] == f"document_metadata_lock_{document_id}" + + def test_knowledge_base_metadata_lock_check_lock_exists( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check when lock already exists. + """ + # Arrange: Setup mocks to simulate existing lock + mock_external_service_dependencies["redis_client"].get.return_value = "1" # Lock exists + + dataset_id = "test-dataset-id" + + # Act & Assert: Verify proper error handling + with pytest.raises( + ValueError, match="Another knowledge base metadata operation is running, please wait a moment." + ): + MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) + + def test_knowledge_base_metadata_lock_check_document_lock_exists( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check when document lock already exists. + """ + # Arrange: Setup mocks to simulate existing lock + mock_external_service_dependencies["redis_client"].get.return_value = "1" # Lock exists + + document_id = "test-document-id" + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Another document metadata operation is running, please wait a moment."): + MetadataService.knowledge_base_metadata_lock_check(None, document_id) + + def test_get_dataset_metadatas_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of dataset metadata information. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Create document and metadata binding + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + binding = DatasetMetadataBinding( + tenant_id=tenant.id, + dataset_id=dataset.id, + metadata_id=metadata.id, + document_id=document.id, + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(binding) + db.session.commit() + + # Act: Execute the method under test + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert: Verify the expected outcomes + assert result is not None + assert "doc_metadata" in result + assert "built_in_field_enabled" in result + + # Verify metadata information + doc_metadata = result["doc_metadata"] + assert len(doc_metadata) == 1 + assert doc_metadata[0]["id"] == metadata.id + assert doc_metadata[0]["name"] == metadata.name + assert doc_metadata[0]["type"] == metadata.type + assert doc_metadata[0]["count"] == 1 # One document bound to this metadata + + # Verify built-in field status + assert result["built_in_field_enabled"] is False + + def test_get_dataset_metadatas_with_built_in_fields_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test retrieval of dataset metadata when built-in fields are enabled. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Enable built-in fields + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Act: Execute the method under test + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert: Verify the expected outcomes + assert result is not None + assert "doc_metadata" in result + assert "built_in_field_enabled" in result + + # Verify metadata information + doc_metadata = result["doc_metadata"] + assert len(doc_metadata) == 1 # Only custom metadata, built-in fields are not included in this list + + # Verify built-in field status + assert result["built_in_field_enabled"] is True + + def test_get_dataset_metadatas_no_metadata(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of dataset metadata when no metadata exists. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Act: Execute the method under test + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert: Verify the expected outcomes + assert result is not None + assert "doc_metadata" in result + assert "built_in_field_enabled" in result + + # Verify metadata information + doc_metadata = result["doc_metadata"] + assert len(doc_metadata) == 0 # No metadata exists + + # Verify built-in field status + assert result["built_in_field_enabled"] is False diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py new file mode 100644 index 0000000000..a8a36b2565 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -0,0 +1,474 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from models.account import TenantAccountJoin, TenantAccountRole +from models.model import Account, Tenant +from models.provider import LoadBalancingModelConfig, Provider, ProviderModelSetting +from services.model_load_balancing_service import ModelLoadBalancingService + + +class TestModelLoadBalancingService: + """Integration tests for ModelLoadBalancingService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager, + patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager, + patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory, + patch("services.model_load_balancing_service.encrypter") as mock_encrypter, + ): + # Setup default mock returns + mock_provider_manager_instance = mock_provider_manager.return_value + + # Mock provider configuration + mock_provider_config = MagicMock() + mock_provider_config.provider.provider = "openai" + mock_provider_config.custom_configuration.provider = None + + # Mock provider model setting + mock_provider_model_setting = MagicMock() + mock_provider_model_setting.load_balancing_enabled = False + + mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting + + # Mock provider configurations dict + mock_provider_configs = {"openai": mock_provider_config} + mock_provider_manager_instance.get_configurations.return_value = mock_provider_configs + + # Mock LBModelManager + mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) + + # Mock ModelProviderFactory + mock_model_provider_factory_instance = mock_model_provider_factory.return_value + + # Mock credential schemas + mock_credential_schema = MagicMock() + mock_credential_schema.credential_form_schemas = [] + + # Mock provider configuration methods + mock_provider_config.extract_secret_variables.return_value = [] + mock_provider_config.obfuscated_credentials.return_value = {} + mock_provider_config._get_credential_schema.return_value = mock_credential_schema + + yield { + "provider_manager": mock_provider_manager, + "lb_model_manager": mock_lb_model_manager, + "model_provider_factory": mock_model_provider_factory, + "encrypter": mock_encrypter, + "provider_config": mock_provider_config, + "provider_model_setting": mock_provider_model_setting, + "credential_schema": mock_credential_schema, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_provider_and_setting( + self, db_session_with_containers, tenant_id, mock_external_service_dependencies + ): + """ + Helper method to create a test provider and provider model setting. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + tenant_id: Tenant ID for the provider + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (provider, provider_model_setting) - Created provider and setting instances + """ + fake = Faker() + + from extensions.ext_database import db + + # Create provider + provider = Provider( + tenant_id=tenant_id, + provider_name="openai", + provider_type="custom", + is_valid=True, + ) + db.session.add(provider) + db.session.commit() + + # Create provider model setting + provider_model_setting = ProviderModelSetting( + tenant_id=tenant_id, + provider_name="openai", + model_name="gpt-3.5-turbo", + model_type="text-generation", # Use the origin model type that matches the query + enabled=True, + load_balancing_enabled=False, + ) + db.session.add(provider_model_setting) + db.session.commit() + + return provider, provider_model_setting + + def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful model load balancing enablement. + + This test verifies: + - Proper provider configuration retrieval + - Successful enablement of model load balancing + - Correct method calls to provider configuration + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Setup mocks for enable method + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_config.enable_model_load_balancing = MagicMock() + + # Act: Execute the method under test + service = ModelLoadBalancingService() + service.enable_model_load_balancing( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + mock_provider_config.enable_model_load_balancing.assert_called_once() + call_args = mock_provider_config.enable_model_load_balancing.call_args + assert call_args.kwargs["model"] == "gpt-3.5-turbo" + assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(provider) + db.session.refresh(provider_model_setting) + assert provider.id is not None + assert provider_model_setting.id is not None + + def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful model load balancing disablement. + + This test verifies: + - Proper provider configuration retrieval + - Successful disablement of model load balancing + - Correct method calls to provider configuration + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Setup mocks for disable method + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_config.disable_model_load_balancing = MagicMock() + + # Act: Execute the method under test + service = ModelLoadBalancingService() + service.disable_model_load_balancing( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + mock_provider_config.disable_model_load_balancing.assert_called_once() + call_args = mock_provider_config.disable_model_load_balancing.call_args + assert call_args.kwargs["model"] == "gpt-3.5-turbo" + assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(provider) + db.session.refresh(provider_model_setting) + assert provider.id is not None + assert provider_model_setting.id is not None + + def test_enable_model_load_balancing_provider_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error handling when provider does not exist. + + This test verifies: + - Proper error handling for non-existent provider + - Correct exception type and message + - No database state changes + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup mocks to return empty provider configurations + mock_provider_manager = mock_external_service_dependencies["provider_manager"] + mock_provider_manager_instance = mock_provider_manager.return_value + mock_provider_manager_instance.get_configurations.return_value = {} + + # Act & Assert: Verify proper error handling + service = ModelLoadBalancingService() + with pytest.raises(ValueError) as exc_info: + service.enable_model_load_balancing( + tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm" + ) + + # Verify correct error message + assert "Provider nonexistent_provider does not exist." in str(exc_info.value) + + # Verify no database state changes occurred + from extensions.ext_database import db + + db.session.rollback() + + def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of load balancing configurations. + + This test verifies: + - Proper provider configuration retrieval + - Successful database query for load balancing configs + - Correct return format and data structure + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Create load balancing config + from extensions.ext_database import db + + load_balancing_config = LoadBalancingModelConfig( + tenant_id=tenant.id, + provider_name="openai", + model_name="gpt-3.5-turbo", + model_type="text-generation", # Use the origin model type that matches the query + name="config1", + encrypted_config='{"api_key": "test_key"}', + enabled=True, + ) + db.session.add(load_balancing_config) + db.session.commit() + + # Verify the config was created + db.session.refresh(load_balancing_config) + assert load_balancing_config.id is not None + + # Setup mocks for get_load_balancing_configs method + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"] + mock_provider_model_setting.load_balancing_enabled = True + + # Mock credential schema methods + mock_credential_schema = mock_external_service_dependencies["credential_schema"] + mock_credential_schema.credential_form_schemas = [] + + # Mock encrypter + mock_encrypter = mock_external_service_dependencies["encrypter"] + mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher") + + # Mock _get_credential_schema method + mock_provider_config._get_credential_schema.return_value = mock_credential_schema + + # Mock extract_secret_variables method + mock_provider_config.extract_secret_variables.return_value = [] + + # Mock obfuscated_credentials method + mock_provider_config.obfuscated_credentials.return_value = {} + + # Mock LBModelManager.get_config_in_cooldown_and_ttl + mock_lb_model_manager = mock_external_service_dependencies["lb_model_manager"] + mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) + + # Act: Execute the method under test + service = ModelLoadBalancingService() + is_enabled, configs = service.get_load_balancing_configs( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + assert is_enabled is True + assert len(configs) == 1 + assert configs[0]["id"] == load_balancing_config.id + assert configs[0]["name"] == "config1" + assert configs[0]["enabled"] is True + assert configs[0]["in_cooldown"] is False + assert configs[0]["ttl"] == 0 + + # Verify database state + db.session.refresh(load_balancing_config) + assert load_balancing_config.id is not None + + def test_get_load_balancing_configs_provider_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error handling when provider does not exist in get_load_balancing_configs. + + This test verifies: + - Proper error handling for non-existent provider + - Correct exception type and message + - No database state changes + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup mocks to return empty provider configurations + mock_provider_manager = mock_external_service_dependencies["provider_manager"] + mock_provider_manager_instance = mock_provider_manager.return_value + mock_provider_manager_instance.get_configurations.return_value = {} + + # Act & Assert: Verify proper error handling + service = ModelLoadBalancingService() + with pytest.raises(ValueError) as exc_info: + service.get_load_balancing_configs( + tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm" + ) + + # Verify correct error message + assert "Provider nonexistent_provider does not exist." in str(exc_info.value) + + # Verify no database state changes occurred + from extensions.ext_database import db + + db.session.rollback() + + def test_get_load_balancing_configs_with_inherit_config( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test load balancing configs retrieval with inherit configuration. + + This test verifies: + - Proper handling of inherit configuration + - Correct ordering of configurations + - Inherit config initialization when needed + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Create load balancing config + from extensions.ext_database import db + + load_balancing_config = LoadBalancingModelConfig( + tenant_id=tenant.id, + provider_name="openai", + model_name="gpt-3.5-turbo", + model_type="text-generation", # Use the origin model type that matches the query + name="config1", + encrypted_config='{"api_key": "test_key"}', + enabled=True, + ) + db.session.add(load_balancing_config) + db.session.commit() + + # Setup mocks for inherit config scenario + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_config.custom_configuration.provider = MagicMock() # Enable custom config + + mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"] + mock_provider_model_setting.load_balancing_enabled = True + + # Mock credential schema methods + mock_credential_schema = mock_external_service_dependencies["credential_schema"] + mock_credential_schema.credential_form_schemas = [] + + # Mock encrypter + mock_encrypter = mock_external_service_dependencies["encrypter"] + mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher") + + # Act: Execute the method under test + service = ModelLoadBalancingService() + is_enabled, configs = service.get_load_balancing_configs( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + assert is_enabled is True + assert len(configs) == 2 # inherit config + existing config + + # First config should be inherit config + assert configs[0]["name"] == "__inherit__" + assert configs[0]["enabled"] is True + + # Second config should be the existing config + assert configs[1]["id"] == load_balancing_config.id + assert configs[1]["name"] == "config1" + + # Verify database state + db.session.refresh(load_balancing_config) + assert load_balancing_config.id is not None + + # Verify inherit config was created in database + inherit_configs = ( + db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.name == "__inherit__").all() + ) + assert len(inherit_configs) == 1 diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py new file mode 100644 index 0000000000..bc46fe8322 --- /dev/null +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -0,0 +1,149 @@ +"""Tests for Celery SSL configuration.""" + +import ssl +from unittest.mock import MagicMock, patch + + +class TestCelerySSLConfiguration: + """Test suite for Celery SSL configuration.""" + + def test_get_celery_ssl_options_when_ssl_disabled(self): + """Test SSL options when REDIS_USE_SSL is False.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = False + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is None + + def test_get_celery_ssl_options_when_broker_not_redis(self): + """Test SSL options when broker is not Redis.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "amqp://localhost:5672" + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is None + + def test_get_celery_ssl_options_with_cert_none(self): + """Test SSL options with CERT_NONE requirement.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE" + mock_config.REDIS_SSL_CA_CERTS = None + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_NONE + assert result["ssl_ca_certs"] is None + assert result["ssl_certfile"] is None + assert result["ssl_keyfile"] is None + + def test_get_celery_ssl_options_with_cert_required(self): + """Test SSL options with CERT_REQUIRED and certificates.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "rediss://localhost:6380/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_REQUIRED" + mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" + mock_config.REDIS_SSL_CERTFILE = "/path/to/client.crt" + mock_config.REDIS_SSL_KEYFILE = "/path/to/client.key" + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_REQUIRED + assert result["ssl_ca_certs"] == "/path/to/ca.crt" + assert result["ssl_certfile"] == "/path/to/client.crt" + assert result["ssl_keyfile"] == "/path/to/client.key" + + def test_get_celery_ssl_options_with_cert_optional(self): + """Test SSL options with CERT_OPTIONAL requirement.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_OPTIONAL" + mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_OPTIONAL + assert result["ssl_ca_certs"] == "/path/to/ca.crt" + + def test_get_celery_ssl_options_with_invalid_cert_reqs(self): + """Test SSL options with invalid cert requirement defaults to CERT_NONE.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "INVALID_VALUE" + mock_config.REDIS_SSL_CA_CERTS = None + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_NONE # Should default to CERT_NONE + + def test_celery_init_applies_ssl_to_broker_and_backend(self): + """Test that SSL options are applied to both broker and backend when using Redis.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.CELERY_BACKEND = "redis" + mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE" + mock_config.REDIS_SSL_CA_CERTS = None + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + mock_config.CELERY_USE_SENTINEL = False + mock_config.LOG_FORMAT = "%(message)s" + mock_config.LOG_TZ = "UTC" + mock_config.LOG_FILE = None + + # Mock all the scheduler configs + mock_config.CELERY_BEAT_SCHEDULER_TIME = 1 + mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False + mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False + mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False + mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False + mock_config.ENABLE_CLEAN_MESSAGES = False + mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False + mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False + mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False + + with patch("extensions.ext_celery.dify_config", mock_config): + from dify_app import DifyApp + from extensions.ext_celery import init_app + + app = DifyApp(__name__) + celery_app = init_app(app) + + # Check that SSL options were applied + assert "broker_use_ssl" in celery_app.conf + assert celery_app.conf["broker_use_ssl"] is not None + assert celery_app.conf["broker_use_ssl"]["ssl_cert_reqs"] == ssl.CERT_NONE + + # Check that SSL is also applied to Redis backend + assert "redis_backend_use_ssl" in celery_app.conf + assert celery_app.conf["redis_backend_use_ssl"] is not None diff --git a/api/uv.lock b/api/uv.lock index cecce2bc43..40091c297a 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1602,7 +1602,7 @@ vdb = [ { name = "pgvector", specifier = "==0.2.5" }, { name = "pymilvus", specifier = "~=2.5.0" }, { name = "pymochow", specifier = "==1.3.1" }, - { name = "pyobvector", specifier = "~=0.1.6" }, + { name = "pyobvector", specifier = "~=0.2.15" }, { name = "qdrant-client", specifier = "==1.9.0" }, { name = "tablestore", specifier = "==6.2.0" }, { name = "tcvectordb", specifier = "~=1.6.4" }, @@ -4569,17 +4569,19 @@ wheels = [ [[package]] name = "pyobvector" -version = "0.1.14" +version = "0.2.15" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiomysql" }, { name = "numpy" }, + { name = "pydantic" }, { name = "pymysql" }, { name = "sqlalchemy" }, + { name = "sqlglot" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/dc/59/7d762061808948dd6aad165a000b34e22163dc83fb5014184eeacc0fabe5/pyobvector-0.1.14.tar.gz", hash = "sha256:4f85cdd63064d040e94c0a96099a0cd5cda18ce625865382e89429f28422fc02", size = 26780, upload-time = "2024-11-20T11:46:18.017Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/7d/3f3aac6acf1fdd1782042d6eecd48efaa2ee355af0dbb61e93292d629391/pyobvector-0.2.15.tar.gz", hash = "sha256:5de258c1e952c88b385b5661e130c1cf8262c498c1f8a4a348a35962d379fce4", size = 39611, upload-time = "2025-08-18T02:49:26.683Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/88/68/ecb21b74c974e7be7f9034e205d08db62d614ff5c221581ae96d37ef853e/pyobvector-0.1.14-py3-none-any.whl", hash = "sha256:828e0bec49a177355b70c7a1270af3b0bf5239200ee0d096e4165b267eeff97c", size = 35526, upload-time = "2024-11-20T11:46:16.809Z" }, + { url = "https://files.pythonhosted.org/packages/5f/1f/a62754ba9b8a02c038d2a96cb641b71d3809f34d2ba4f921fecd7840d7fb/pyobvector-0.2.15-py3-none-any.whl", hash = "sha256:feeefe849ee5400e72a9a4d3844e425a58a99053dd02abe06884206923065ebb", size = 52680, upload-time = "2025-08-18T02:49:25.452Z" }, ] [[package]] @@ -5432,6 +5434,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" }, ] +[[package]] +name = "sqlglot" +version = "26.33.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/25/9d/fcd59b4612d5ad1e2257c67c478107f073b19e1097d3bfde2fb517884416/sqlglot-26.33.0.tar.gz", hash = "sha256:2817278779fa51d6def43aa0d70690b93a25c83eb18ec97130fdaf707abc0d73", size = 5353340, upload-time = "2025-07-01T13:09:06.311Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/8d/f1d9cb5b18e06aa45689fbeaaea6ebab66d5f01d1e65029a8f7657c06be5/sqlglot-26.33.0-py3-none-any.whl", hash = "sha256:031cee20c0c796a83d26d079a47fdce667604df430598c7eabfa4e4dfd147033", size = 477610, upload-time = "2025-07-01T13:09:03.926Z" }, +] + [[package]] name = "sseclient-py" version = "1.8.0" diff --git a/docker/.env.example b/docker/.env.example index bd614f78f1..826b7b9fe6 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -264,6 +264,15 @@ REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false +# SSL configuration for Redis (when REDIS_USE_SSL=true) +REDIS_SSL_CERT_REQS=CERT_NONE +# Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED +REDIS_SSL_CA_CERTS= +# Path to CA certificate file for SSL verification +REDIS_SSL_CERTFILE= +# Path to client certificate file for SSL authentication +REDIS_SSL_KEYFILE= +# Path to client private key file for SSL authentication REDIS_DB=0 # Whether to use Redis Sentinel mode. @@ -878,6 +887,14 @@ API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository. # API workflow node execution repository implementation API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository +# Workflow log cleanup configuration +# Enable automatic cleanup of workflow run logs to manage database size +WORKFLOW_LOG_CLEANUP_ENABLED=false +# Number of days to retain workflow run logs (default: 30 days) +WORKFLOW_LOG_RETENTION_DAYS=30 +# Batch size for workflow log cleanup operations (default: 100) +WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100 + # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index afb20cb53b..0c352e4658 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -71,6 +71,10 @@ x-shared-env: &shared-api-worker-env REDIS_USERNAME: ${REDIS_USERNAME:-} REDIS_PASSWORD: ${REDIS_PASSWORD:-difyai123456} REDIS_USE_SSL: ${REDIS_USE_SSL:-false} + REDIS_SSL_CERT_REQS: ${REDIS_SSL_CERT_REQS:-CERT_NONE} + REDIS_SSL_CA_CERTS: ${REDIS_SSL_CA_CERTS:-} + REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-} + REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-} REDIS_DB: ${REDIS_DB:-0} REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false} REDIS_SENTINELS: ${REDIS_SENTINELS:-} @@ -392,6 +396,9 @@ x-shared-env: &shared-api-worker-env CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository} API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository} API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository} + WORKFLOW_LOG_CLEANUP_ENABLED: ${WORKFLOW_LOG_CLEANUP_ENABLED:-false} + WORKFLOW_LOG_RETENTION_DAYS: ${WORKFLOW_LOG_RETENTION_DAYS:-30} + WORKFLOW_LOG_CLEANUP_BATCH_SIZE: ${WORKFLOW_LOG_CLEANUP_BATCH_SIZE:-100} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} diff --git a/web/app/account/account-page/AvatarWithEdit.tsx b/web/app/account/account-page/AvatarWithEdit.tsx index 41a6971bf5..88e3a7b343 100644 --- a/web/app/account/account-page/AvatarWithEdit.tsx +++ b/web/app/account/account-page/AvatarWithEdit.tsx @@ -4,7 +4,7 @@ import type { Area } from 'react-easy-crop' import React, { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { RiPencilLine } from '@remixicon/react' +import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react' import { updateUserProfile } from '@/service/common' import { ToastContext } from '@/app/components/base/toast' import ImageInput, { type OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput' @@ -27,6 +27,8 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { const [inputImageInfo, setInputImageInfo] = useState() const [isShowAvatarPicker, setIsShowAvatarPicker] = useState(false) const [uploading, setUploading] = useState(false) + const [isShowDeleteConfirm, setIsShowDeleteConfirm] = useState(false) + const [hoverArea, setHoverArea] = useState('left') const handleImageInput: OnImageInput = useCallback(async (isCropped: boolean, fileOrTempUrl: string | File, croppedAreaPixels?: Area, fileName?: string) => { setInputImageInfo( @@ -48,6 +50,18 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { } }, [notify, onSave, t]) + const handleDeleteAvatar = useCallback(async () => { + try { + await updateUserProfile({ url: 'account/avatar', body: { avatar: '' } }) + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + setIsShowDeleteConfirm(false) + onSave?.() + } + catch (e) { + notify({ type: 'error', message: (e as Error).message }) + } + }, [notify, onSave, t]) + const { handleLocalFileUpload } = useLocalFileUploader({ limit: 3, disabled: false, @@ -86,12 +100,21 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
{ setIsShowAvatarPicker(true) }} className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black/50 opacity-0 transition-opacity group-hover:opacity-100" + onClick={() => hoverArea === 'right' ? setIsShowDeleteConfirm(true) : setIsShowAvatarPicker(true)} + onMouseMove={(e) => { + const rect = e.currentTarget.getBoundingClientRect() + const x = e.clientX - rect.left + const isRight = x > rect.width / 2 + setHoverArea(isRight ? 'right' : 'left') + }} > - + {hoverArea === 'right' ? + + : - + } +
@@ -115,6 +138,26 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { + + setIsShowDeleteConfirm(false)} + > +
{t('common.avatar.deleteTitle')}
+

{t('common.avatar.deleteDescription')}

+ +
+ + + +
+
) } diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index eb0f524386..a7bdc550d1 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -13,7 +13,7 @@ import Tooltip from '@/app/components/base/tooltip' import { AppType } from '@/types/app' import { getNewVar, getVars } from '@/utils/var' import AutomaticBtn from '@/app/components/app/configuration/config/automatic/automatic-btn' -import type { AutomaticRes } from '@/service/debug' +import type { GenRes } from '@/service/debug' import GetAutomaticResModal from '@/app/components/app/configuration/config/automatic/get-automatic-res' import PromptEditor from '@/app/components/base/prompt-editor' import ConfigContext from '@/context/debug-configuration' @@ -61,6 +61,7 @@ const Prompt: FC = ({ const { eventEmitter } = useEventEmitterContextContext() const { + appId, modelConfig, dataSets, setModelConfig, @@ -139,21 +140,21 @@ const Prompt: FC = ({ } const [showAutomatic, { setTrue: showAutomaticTrue, setFalse: showAutomaticFalse }] = useBoolean(false) - const handleAutomaticRes = (res: AutomaticRes) => { + const handleAutomaticRes = (res: GenRes) => { // put eventEmitter in first place to prevent overwrite the configs.prompt_variables.But another problem is that prompt won't hight the prompt_variables. eventEmitter?.emit({ type: PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER, - payload: res.prompt, + payload: res.modified, } as any) const newModelConfig = produce(modelConfig, (draft) => { - draft.configs.prompt_template = res.prompt - draft.configs.prompt_variables = res.variables.map(key => ({ key, name: key, type: 'string', required: true })) + draft.configs.prompt_template = res.modified + draft.configs.prompt_variables = (res.variables || []).map(key => ({ key, name: key, type: 'string', required: true })) }) setModelConfig(newModelConfig) setPrevPromptConfig(modelConfig.configs) if (mode !== AppType.completion) { - setIntroduction(res.opening_statement) + setIntroduction(res.opening_statement || '') const newFeatures = produce(features, (draft) => { draft.opening = { ...draft.opening, @@ -272,10 +273,13 @@ const Prompt: FC = ({ {showAutomatic && ( )} diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index aacaa81ac2..31f81d274d 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import React, { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useBoolean } from 'ahooks' +import { useBoolean, useSessionStorageState } from 'ahooks' import { RiDatabase2Line, RiFileExcel2Line, @@ -14,24 +14,18 @@ import { RiTranslate, RiUser2Line, } from '@remixicon/react' -import cn from 'classnames' import s from './style.module.css' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' -import Textarea from '@/app/components/base/textarea' import Toast from '@/app/components/base/toast' -import { generateRule } from '@/service/debug' -import ConfigPrompt from '@/app/components/app/configuration/config-prompt' +import { generateBasicAppFistTimeRule, generateRule } from '@/service/debug' import type { CompletionParams, Model } from '@/types/app' -import { AppType } from '@/types/app' -import ConfigVar from '@/app/components/app/configuration/config-var' -import GroupName from '@/app/components/app/configuration/base/group-name' +import type { AppType } from '@/types/app' import Loading from '@/app/components/base/loading' import Confirm from '@/app/components/base/confirm' -import { LoveMessage } from '@/app/components/base/icons/src/vender/features' // type -import type { AutomaticRes } from '@/service/debug' +import type { GenRes } from '@/service/debug' import { Generator } from '@/app/components/base/icons/src/vender/other' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' @@ -39,13 +33,25 @@ import { ModelTypeEnum } from '@/app/components/header/account-setting/model-pro import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import type { ModelModeType } from '@/types/app' import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' +import InstructionEditorInWorkflow from './instruction-editor-in-workflow' +import InstructionEditorInBasic from './instruction-editor' +import { GeneratorType } from './types' +import Result from './result' +import useGenData from './use-gen-data' +import IdeaOutput from './idea-output' +import ResPlaceholder from './res-placeholder' +import { useGenerateRuleTemplate } from '@/service/use-apps' +const i18nPrefix = 'appDebug.generate' export type IGetAutomaticResProps = { mode: AppType isShow: boolean onClose: () => void - onFinished: (res: AutomaticRes) => void - isInLLMNode?: boolean + onFinished: (res: GenRes) => void + flowId?: string + nodeId?: string + currentPrompt?: string + isBasicMode?: boolean } const TryLabel: FC<{ @@ -68,7 +74,10 @@ const GetAutomaticRes: FC = ({ mode, isShow, onClose, - isInLLMNode, + flowId, + nodeId, + currentPrompt, + isBasicMode, onFinished, }) => { const { t } = useTranslation() @@ -123,13 +132,27 @@ const GetAutomaticRes: FC = ({ }, ] - const [instruction, setInstruction] = useState('') + const [instructionFromSessionStorage, setInstruction] = useSessionStorageState(`improve-instruction-${flowId}${isBasicMode ? '' : `-${nodeId}`}`) + const instruction = instructionFromSessionStorage || '' + const [ideaOutput, setIdeaOutput] = useState('') + + const [editorKey, setEditorKey] = useState(`${flowId}-0`) const handleChooseTemplate = useCallback((key: string) => { return () => { const template = t(`appDebug.generate.template.${key}.instruction`) setInstruction(template) + setEditorKey(`${flowId}-${Date.now()}`) } }, [t]) + + const { data: instructionTemplate } = useGenerateRuleTemplate(GeneratorType.prompt, isBasicMode) + useEffect(() => { + if (!instruction && instructionTemplate) + setInstruction(instructionTemplate.data) + + setEditorKey(`${flowId}-${Date.now()}`) + }, [instructionTemplate]) + const isValid = () => { if (instruction.trim() === '') { Toast.notify({ @@ -143,7 +166,10 @@ const GetAutomaticRes: FC = ({ return true } const [isLoading, { setTrue: setLoadingTrue, setFalse: setLoadingFalse }] = useBoolean(false) - const [res, setRes] = useState(null) + const storageKey = `${flowId}${isBasicMode ? '' : `-${nodeId}`}` + const { addVersion, current, currentVersionIndex, setCurrentVersionIndex, versions } = useGenData({ + storageKey, + }) useEffect(() => { if (defaultModel) { @@ -170,16 +196,6 @@ const GetAutomaticRes: FC = ({ ) - const renderNoData = ( -
- -
-
{t('appDebug.generate.noDataLine1')}
-
{t('appDebug.generate.noDataLine2')}
-
-
- ) - const handleModelChange = useCallback((newValue: { modelId: string; provider: string; mode?: string; features?: string[] }) => { const newModel = { ...model, @@ -207,28 +223,59 @@ const GetAutomaticRes: FC = ({ return setLoadingTrue() try { - const { error, ...res } = await generateRule({ - instruction, - model_config: model, - no_variable: !!isInLLMNode, - }) - setRes(res) - if (error) { - Toast.notify({ - type: 'error', - message: error, + let apiRes: GenRes + let hasError = false + if (isBasicMode || !currentPrompt) { + const { error, ...res } = await generateBasicAppFistTimeRule({ + instruction, + model_config: model, + no_variable: false, }) + apiRes = { + ...res, + modified: res.prompt, + } as GenRes + if (error) { + hasError = true + Toast.notify({ + type: 'error', + message: error, + }) + } } + else { + const { error, ...res } = await generateRule({ + flow_id: flowId, + node_id: nodeId, + current: currentPrompt, + instruction, + ideal_output: ideaOutput, + model_config: model, + }) + apiRes = res + if (error) { + hasError = true + Toast.notify({ + type: 'error', + message: error, + }) + } + } + if (!hasError) + addVersion(apiRes) } finally { setLoadingFalse() } } - const [showConfirmOverwrite, setShowConfirmOverwrite] = React.useState(false) + const [isShowConfirmOverwrite, { + setTrue: showConfirmOverwrite, + setFalse: hideShowConfirmOverwrite, + }] = useBoolean(false) const isShowAutoPromptResPlaceholder = () => { - return !isLoading && !res + return !isLoading && !current } return ( @@ -236,15 +283,14 @@ const GetAutomaticRes: FC = ({ isShow={isShow} onClose={onClose} className='min-w-[1140px] !p-0' - closable >
-
+
{t('appDebug.generate.title')}
{t('appDebug.generate.description')}
-
+
= ({ hideDebugWithMultipleModel />
-
-
-
{t('appDebug.generate.tryIt')}
-
-
-
- {tryList.map(item => ( - - ))} -
-
- {/* inputs */} -
-
-
{t('appDebug.generate.instruction')}
-