diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 94e5b0f969..d6f326d4dc 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,6 +9,14 @@ # Backend (default owner, more specific rules below will override) api/ @QuantumGhost +# Backend - MCP +api/core/mcp/ @Nov1c444 +api/core/entities/mcp_provider.py @Nov1c444 +api/services/tools/mcp_tools_manage_service.py @Nov1c444 +api/controllers/mcp/ @Nov1c444 +api/controllers/console/app/mcp_server.py @Nov1c444 +api/tests/**/*mcp* @Nov1c444 + # Backend - Workflow - Engine (Core graph execution engine) api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost api/core/workflow/runtime/ @laipz8200 @QuantumGhost diff --git a/.github/ISSUE_TEMPLATE/refactor.yml b/.github/ISSUE_TEMPLATE/refactor.yml index cf74dcc546..dbe8cbb602 100644 --- a/.github/ISSUE_TEMPLATE/refactor.yml +++ b/.github/ISSUE_TEMPLATE/refactor.yml @@ -1,8 +1,6 @@ -name: "✨ Refactor" -description: Refactor existing code for improved readability and maintainability. -title: "[Chore/Refactor] " -labels: - - refactor +name: "✨ Refactor or Chore" +description: Refactor existing code or perform maintenance chores to improve readability and reliability. +title: "[Refactor/Chore] " body: - type: checkboxes attributes: @@ -11,7 +9,7 @@ body: options: - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). required: true - - label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). + - label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). required: true - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true @@ -25,14 +23,14 @@ body: id: description attributes: label: Description - placeholder: "Describe the refactor you are proposing." + placeholder: "Describe the refactor or chore you are proposing." validations: required: true - type: textarea id: motivation attributes: label: Motivation - placeholder: "Explain why this refactor is necessary." + placeholder: "Explain why this refactor or chore is necessary." validations: required: false - type: textarea diff --git a/.github/ISSUE_TEMPLATE/tracker.yml b/.github/ISSUE_TEMPLATE/tracker.yml deleted file mode 100644 index 35fedefc75..0000000000 --- a/.github/ISSUE_TEMPLATE/tracker.yml +++ /dev/null @@ -1,13 +0,0 @@ -name: "👾 Tracker" -description: For inner usages, please do not use this template. -title: "[Tracker] " -labels: - - tracker -body: - - type: textarea - id: content - attributes: - label: Blockers - placeholder: "- [ ] ..." - validations: - required: true diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml new file mode 100644 index 0000000000..b15c26a096 --- /dev/null +++ b/.github/workflows/semantic-pull-request.yml @@ -0,0 +1,21 @@ +name: Semantic Pull Request + +on: + pull_request: + types: + - opened + - edited + - reopened + - synchronize + +jobs: + lint: + name: Validate PR title + permissions: + pull-requests: read + runs-on: ubuntu-latest + steps: + - name: Check title + uses: amannn/action-semantic-pull-request@v6.1.1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index e652657705..5a8a34be79 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -106,7 +106,7 @@ jobs: - name: Web type check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: pnpm run type-check + run: pnpm run type-check:tsgo docker-compose-template: name: Docker Compose Template diff --git a/.nvmrc b/.nvmrc new file mode 100644 index 0000000000..7af24b7ddb --- /dev/null +++ b/.nvmrc @@ -0,0 +1 @@ +22.11.0 diff --git a/AGENTS.md b/AGENTS.md index 2ef7931efc..782861ad36 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -24,8 +24,8 @@ The codebase is split into: ```bash cd web -pnpm lint pnpm lint:fix +pnpm type-check:tsgo pnpm test ``` @@ -39,7 +39,7 @@ pnpm test ## Language Style - **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). -- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types. +- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types. ## General Practices diff --git a/README.md b/README.md index 09ba1f634b..b71764a214 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,19 @@ Star Dify on GitHub and be instantly notified of new releases. If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). +#### Customizing Suggested Questions + +You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions: + +```bash +# In your .env file +SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]' +SUGGESTED_QUESTIONS_MAX_TOKENS=512 +SUGGESTED_QUESTIONS_TEMPERATURE=0.3 +``` + +See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions. + ### Metrics Monitoring with Grafana Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more. diff --git a/api/.env.example b/api/.env.example index 50607f5b35..516a119d98 100644 --- a/api/.env.example +++ b/api/.env.example @@ -633,8 +633,30 @@ SWAGGER_UI_PATH=/swagger-ui.html # Set to false to export dataset IDs as plain text for easier cross-environment import DSL_EXPORT_ENCRYPT_DATASET_ID=true +# Suggested Questions After Answer Configuration +# These environment variables allow customization of the suggested questions feature +# +# Custom prompt for generating suggested questions (optional) +# If not set, uses the default prompt that generates 3 questions under 20 characters each +# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]" +# SUGGESTED_QUESTIONS_PROMPT= + +# Maximum number of tokens for suggested questions generation (default: 256) +# Adjust this value for longer questions or more questions +# SUGGESTED_QUESTIONS_MAX_TOKENS=256 + +# Temperature for suggested questions generation (default: 0.0) +# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions +# SUGGESTED_QUESTIONS_TEMPERATURE=0 + # Tenant isolated task queue configuration TENANT_ISOLATED_TASK_CONCURRENCY=1 # Maximum number of segments for dataset segments API (0 for unlimited) DATASET_MAX_SEGMENTS_PER_REQUEST=0 + +# Multimodal knowledgebase limit +SINGLE_CHUNK_ATTACHMENT_LIMIT=10 +ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2 +ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60 +IMAGE_FILE_BATCH_LIMIT=10 diff --git a/api/.ruff.toml b/api/.ruff.toml index 5a29e1d8fa..7206f7fa0f 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -36,17 +36,20 @@ select = [ "UP", # pyupgrade rules "W191", # tab-indentation "W605", # invalid-escape-sequence + "G001", # don't use str format to logging messages + "G003", # don't use + in logging messages + "G004", # don't use f-strings to format logging messages + "UP042", # use StrEnum, + "S110", # disallow the try-except-pass pattern. + # security related linting rules # RCE proctection (sort of) "S102", # exec-builtin, disallow use of `exec` "S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval` "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. "S302", # suspicious-marshal-usage, disallow use of `marshal` module - "S311", # suspicious-non-cryptographic-random-usage - "G001", # don't use str format to logging messages - "G003", # don't use + in logging messages - "G004", # don't use f-strings to format logging messages - "UP042", # use StrEnum + "S311", # suspicious-non-cryptographic-random-usage, + ] ignore = [ @@ -91,18 +94,16 @@ ignore = [ "configs/*" = [ "N802", # invalid-function-name ] -"core/model_runtime/callbacks/base_callback.py" = [ - "T201", -] -"core/workflow/callbacks/workflow_logging_callback.py" = [ - "T201", -] +"core/model_runtime/callbacks/base_callback.py" = ["T201"] +"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"] "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name ] "tests/*" = [ "F811", # redefined-while-unused - "T201", # allow print in tests + "T201", # allow print in tests, + "S110", # allow ignoring exceptions in tests code (currently) + ] [lint.pyflakes] diff --git a/api/app_factory.py b/api/app_factory.py index ad2065682c..3a3ee03cff 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -1,6 +1,8 @@ import logging import time +from opentelemetry.trace import get_current_span + from configs import dify_config from contexts.wrapper import RecyclableContextVar from dify_app import DifyApp @@ -26,8 +28,25 @@ def create_flask_app_with_configs() -> DifyApp: # add an unique identifier to each request RecyclableContextVar.increment_thread_recycles() + # add after request hook for injecting X-Trace-Id header from OpenTelemetry span context + @dify_app.after_request + def add_trace_id_header(response): + try: + span = get_current_span() + ctx = span.get_span_context() if span else None + if ctx and ctx.is_valid: + trace_id_hex = format(ctx.trace_id, "032x") + # Avoid duplicates if some middleware added it + if "X-Trace-Id" not in response.headers: + response.headers["X-Trace-Id"] = trace_id_hex + except Exception: + # Never break the response due to tracing header injection + logger.warning("Failed to add trace ID to response header", exc_info=True) + return response + # Capture the decorator's return value to avoid pyright reportUnusedFunction _ = before_request + _ = add_trace_id_header return dify_app diff --git a/api/commands.py b/api/commands.py index e15c996a34..a8d89ac200 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1139,6 +1139,7 @@ def remove_orphaned_files_on_storage(force: bool): click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) except Exception as e: click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) + return all_files_on_storage = [] for storage_path in storage_paths: diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 9c0c48c955..a5916241df 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -360,6 +360,26 @@ class FileUploadConfig(BaseSettings): default=10, ) + IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field( + description="Maximum number of files allowed in a image batch upload operation", + default=10, + ) + + SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field( + description="Maximum number of files allowed in a single chunk attachment", + default=10, + ) + + ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="Maximum allowed image file size for attachments in megabytes", + default=2, + ) + + ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field( + description="Timeout for downloading image attachments in seconds", + default=60, + ) + inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field( description=( "Comma-separated list of file extensions that are blocked from upload. " @@ -553,7 +573,10 @@ class LoggingConfig(BaseSettings): LOG_FORMAT: str = Field( description="Format string for log messages", - default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", + default=( + "%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] " + "[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s" + ), ) LOG_DATEFORMAT: str | None = Field( diff --git a/api/controllers/common/schema.py b/api/controllers/common/schema.py new file mode 100644 index 0000000000..e0896a8dc2 --- /dev/null +++ b/api/controllers/common/schema.py @@ -0,0 +1,26 @@ +"""Helpers for registering Pydantic models with Flask-RESTX namespaces.""" + +from flask_restx import Namespace +from pydantic import BaseModel + +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None: + """Register a single BaseModel with a namespace for Swagger documentation.""" + + namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None: + """Register multiple BaseModels with a namespace.""" + + for model in models: + register_schema_model(namespace, model) + + +__all__ = [ + "DEFAULT_REF_TEMPLATE_SWAGGER_2_0", + "register_schema_model", + "register_schema_models", +] diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index da9282cd0c..7aa1e6dbd8 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -3,7 +3,8 @@ from functools import wraps from typing import ParamSpec, TypeVar from flask import request -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized @@ -18,6 +19,30 @@ from extensions.ext_database import db from libs.token import extract_access_token from models.model import App, InstalledApp, RecommendedApp +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class InsertExploreAppPayload(BaseModel): + app_id: str = Field(...) + desc: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + language: str = Field(...) + category: str = Field(...) + position: int = Field(...) + + @field_validator("language") + @classmethod + def validate_language(cls, value: str) -> str: + return supported_language(value) + + +console_ns.schema_model( + InsertExploreAppPayload.__name__, + InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + def admin_required(view: Callable[P, R]): @wraps(view) @@ -40,59 +65,34 @@ def admin_required(view: Callable[P, R]): class InsertExploreAppListApi(Resource): @console_ns.doc("insert_explore_app") @console_ns.doc(description="Insert or update an app in the explore list") - @console_ns.expect( - console_ns.model( - "InsertExploreAppRequest", - { - "app_id": fields.String(required=True, description="Application ID"), - "desc": fields.String(description="App description"), - "copyright": fields.String(description="Copyright information"), - "privacy_policy": fields.String(description="Privacy policy"), - "custom_disclaimer": fields.String(description="Custom disclaimer"), - "language": fields.String(required=True, description="Language code"), - "category": fields.String(required=True, description="App category"), - "position": fields.Integer(required=True, description="Display position"), - }, - ) - ) + @console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__]) @console_ns.response(200, "App updated successfully") @console_ns.response(201, "App inserted successfully") @console_ns.response(404, "App not found") @only_edition_cloud @admin_required def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("app_id", type=str, required=True, nullable=False, location="json") - .add_argument("desc", type=str, location="json") - .add_argument("copyright", type=str, location="json") - .add_argument("privacy_policy", type=str, location="json") - .add_argument("custom_disclaimer", type=str, location="json") - .add_argument("language", type=supported_language, required=True, nullable=False, location="json") - .add_argument("category", type=str, required=True, nullable=False, location="json") - .add_argument("position", type=int, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = InsertExploreAppPayload.model_validate(console_ns.payload) - app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() + app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none() if not app: - raise NotFound(f"App '{args['app_id']}' is not found") + raise NotFound(f"App '{payload.app_id}' is not found") site = app.site if not site: - desc = args["desc"] or "" - copy_right = args["copyright"] or "" - privacy_policy = args["privacy_policy"] or "" - custom_disclaimer = args["custom_disclaimer"] or "" + desc = payload.desc or "" + copy_right = payload.copyright or "" + privacy_policy = payload.privacy_policy or "" + custom_disclaimer = payload.custom_disclaimer or "" else: - desc = site.description or args["desc"] or "" - copy_right = site.copyright or args["copyright"] or "" - privacy_policy = site.privacy_policy or args["privacy_policy"] or "" - custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" + desc = site.description or payload.desc or "" + copy_right = site.copyright or payload.copyright or "" + privacy_policy = site.privacy_policy or payload.privacy_policy or "" + custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or "" with Session(db.engine) as session: recommended_app = session.execute( - select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]) + select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id) ).scalar_one_or_none() if not recommended_app: @@ -102,9 +102,9 @@ class InsertExploreAppListApi(Resource): copyright=copy_right, privacy_policy=privacy_policy, custom_disclaimer=custom_disclaimer, - language=args["language"], - category=args["category"], - position=args["position"], + language=payload.language, + category=payload.category, + position=payload.position, ) db.session.add(recommended_app) @@ -118,9 +118,9 @@ class InsertExploreAppListApi(Resource): recommended_app.copyright = copy_right recommended_app.privacy_policy = privacy_policy recommended_app.custom_disclaimer = custom_disclaimer - recommended_app.language = args["language"] - recommended_app.category = args["category"] - recommended_app.position = args["position"] + recommended_app.language = payload.language + recommended_app.category = payload.category + recommended_app.position = payload.position app.is_public = True diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index 7e31d0a844..cfdb9cf417 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -1,4 +1,6 @@ -from flask_restx import Resource, fields, reqparse +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.app.wraps import get_app_model @@ -8,10 +10,21 @@ from libs.login import login_required from models.model import AppMode from services.agent_service import AgentService -parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID") - .add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID") +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class AgentLogQuery(BaseModel): + message_id: str = Field(..., description="Message UUID") + conversation_id: str = Field(..., description="Conversation UUID") + + @field_validator("message_id", "conversation_id") + @classmethod + def validate_uuid(cls, value: str) -> str: + return uuid_value(value) + + +console_ns.schema_model( + AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) @@ -20,7 +33,7 @@ class AgentLogApi(Resource): @console_ns.doc("get_agent_logs") @console_ns.doc(description="Get agent execution logs for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[AgentLogQuery.__name__]) @console_ns.response( 200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")) ) @@ -31,6 +44,6 @@ class AgentLogApi(Resource): @get_app_model(mode=[AppMode.AGENT_CHAT]) def get(self, app_model): """Get agent logs""" - args = parser.parse_args() + args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) + return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index edf0cc2cec..3b6fb58931 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,7 +1,8 @@ -from typing import Literal +from typing import Any, Literal from flask import request -from flask_restx import Resource, fields, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel, Field, field_validator from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.console import console_ns @@ -21,22 +22,79 @@ from libs.helper import uuid_value from libs.login import login_required from services.annotation_service import AppAnnotationService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class AnnotationReplyPayload(BaseModel): + score_threshold: float = Field(..., description="Score threshold for annotation matching") + embedding_provider_name: str = Field(..., description="Embedding provider name") + embedding_model_name: str = Field(..., description="Embedding model name") + + +class AnnotationSettingUpdatePayload(BaseModel): + score_threshold: float = Field(..., description="Score threshold") + + +class AnnotationListQuery(BaseModel): + page: int = Field(default=1, ge=1, description="Page number") + limit: int = Field(default=20, ge=1, description="Page size") + keyword: str = Field(default="", description="Search keyword") + + +class CreateAnnotationPayload(BaseModel): + message_id: str | None = Field(default=None, description="Message ID") + question: str | None = Field(default=None, description="Question text") + answer: str | None = Field(default=None, description="Answer text") + content: str | None = Field(default=None, description="Content text") + annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data") + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class UpdateAnnotationPayload(BaseModel): + question: str | None = None + answer: str | None = None + content: str | None = None + annotation_reply: dict[str, Any] | None = None + + +class AnnotationReplyStatusQuery(BaseModel): + action: Literal["enable", "disable"] + + +class AnnotationFilePayload(BaseModel): + message_id: str = Field(..., description="Message ID") + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str) -> str: + return uuid_value(value) + + +def reg(model: type[BaseModel]) -> None: + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(AnnotationReplyPayload) +reg(AnnotationSettingUpdatePayload) +reg(AnnotationListQuery) +reg(CreateAnnotationPayload) +reg(UpdateAnnotationPayload) +reg(AnnotationReplyStatusQuery) +reg(AnnotationFilePayload) + @console_ns.route("/apps//annotation-reply/") class AnnotationReplyActionApi(Resource): @console_ns.doc("annotation_reply_action") @console_ns.doc(description="Enable or disable annotation reply for an app") @console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"}) - @console_ns.expect( - console_ns.model( - "AnnotationReplyActionRequest", - { - "score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"), - "embedding_provider_name": fields.String(required=True, description="Embedding provider name"), - "embedding_model_name": fields.String(required=True, description="Embedding model name"), - }, - ) - ) + @console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__]) @console_ns.response(200, "Action completed successfully") @console_ns.response(403, "Insufficient permissions") @setup_required @@ -46,15 +104,9 @@ class AnnotationReplyActionApi(Resource): @edit_permission_required def post(self, app_id, action: Literal["enable", "disable"]): app_id = str(app_id) - parser = ( - reqparse.RequestParser() - .add_argument("score_threshold", required=True, type=float, location="json") - .add_argument("embedding_provider_name", required=True, type=str, location="json") - .add_argument("embedding_model_name", required=True, type=str, location="json") - ) - args = parser.parse_args() + args = AnnotationReplyPayload.model_validate(console_ns.payload) if action == "enable": - result = AppAnnotationService.enable_app_annotation(args, app_id) + result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_id) return result, 200 @@ -82,16 +134,7 @@ class AppAnnotationSettingUpdateApi(Resource): @console_ns.doc("update_annotation_setting") @console_ns.doc(description="Update annotation settings for an app") @console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"}) - @console_ns.expect( - console_ns.model( - "AnnotationSettingUpdateRequest", - { - "score_threshold": fields.Float(required=True, description="Score threshold"), - "embedding_provider_name": fields.String(required=True, description="Embedding provider"), - "embedding_model_name": fields.String(required=True, description="Embedding model"), - }, - ) - ) + @console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__]) @console_ns.response(200, "Settings updated successfully") @console_ns.response(403, "Insufficient permissions") @setup_required @@ -102,10 +145,9 @@ class AppAnnotationSettingUpdateApi(Resource): app_id = str(app_id) annotation_setting_id = str(annotation_setting_id) - parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json") - args = parser.parse_args() + args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload) - result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) + result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump()) return result, 200 @@ -142,12 +184,7 @@ class AnnotationApi(Resource): @console_ns.doc("list_annotations") @console_ns.doc(description="Get annotations for an app with pagination") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.parser() - .add_argument("page", type=int, location="args", default=1, help="Page number") - .add_argument("limit", type=int, location="args", default=20, help="Page size") - .add_argument("keyword", type=str, location="args", default="", help="Search keyword") - ) + @console_ns.expect(console_ns.models[AnnotationListQuery.__name__]) @console_ns.response(200, "Annotations retrieved successfully") @console_ns.response(403, "Insufficient permissions") @setup_required @@ -155,9 +192,10 @@ class AnnotationApi(Resource): @account_initialization_required @edit_permission_required def get(self, app_id): - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - keyword = request.args.get("keyword", default="", type=str) + args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + page = args.page + limit = args.limit + keyword = args.keyword app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) @@ -173,18 +211,7 @@ class AnnotationApi(Resource): @console_ns.doc("create_annotation") @console_ns.doc(description="Create a new annotation for an app") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "CreateAnnotationRequest", - { - "message_id": fields.String(description="Message ID (optional)"), - "question": fields.String(description="Question text (required when message_id not provided)"), - "answer": fields.String(description="Answer text (use 'answer' or 'content')"), - "content": fields.String(description="Content text (use 'answer' or 'content')"), - "annotation_reply": fields.Raw(description="Annotation reply data"), - }, - ) - ) + @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__]) @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -195,16 +222,9 @@ class AnnotationApi(Resource): @edit_permission_required def post(self, app_id): app_id = str(app_id) - parser = ( - reqparse.RequestParser() - .add_argument("message_id", required=False, type=uuid_value, location="json") - .add_argument("question", required=False, type=str, location="json") - .add_argument("answer", required=False, type=str, location="json") - .add_argument("content", required=False, type=str, location="json") - .add_argument("annotation_reply", required=False, type=dict, location="json") - ) - args = parser.parse_args() - annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id) + args = CreateAnnotationPayload.model_validate(console_ns.payload) + data = args.model_dump(exclude_none=True) + annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) return annotation @setup_required @@ -256,13 +276,6 @@ class AnnotationExportApi(Resource): return response, 200 -parser = ( - reqparse.RequestParser() - .add_argument("question", required=True, type=str, location="json") - .add_argument("answer", required=True, type=str, location="json") -) - - @console_ns.route("/apps//annotations/") class AnnotationUpdateDeleteApi(Resource): @console_ns.doc("update_delete_annotation") @@ -271,7 +284,7 @@ class AnnotationUpdateDeleteApi(Resource): @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns)) @console_ns.response(204, "Annotation deleted successfully") @console_ns.response(403, "Insufficient permissions") - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -281,8 +294,10 @@ class AnnotationUpdateDeleteApi(Resource): def post(self, app_id, annotation_id): app_id = str(app_id) annotation_id = str(annotation_id) - args = parser.parse_args() - annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) + args = UpdateAnnotationPayload.model_validate(console_ns.payload) + annotation = AppAnnotationService.update_app_annotation_directly( + args.model_dump(exclude_none=True), app_id, annotation_id + ) return annotation @setup_required diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index d6adacd84d..62e997dae2 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -31,7 +31,6 @@ from fields.app_fields import ( from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict from libs.helper import AppIconUrlField, TimestampField from libs.login import current_account_with_tenant, login_required -from libs.validators import validate_description_length from models import App, Workflow from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService @@ -76,51 +75,30 @@ class AppListQuery(BaseModel): class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") - description: str | None = Field(default=None, description="App description (max 400 chars)") + description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode") icon_type: str | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") - @field_validator("description") - @classmethod - def validate_description(cls, value: str | None) -> str | None: - if value is None: - return value - return validate_description_length(value) - class UpdateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") - description: str | None = Field(default=None, description="App description (max 400 chars)") + description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) icon_type: str | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") max_active_requests: int | None = Field(default=None, description="Maximum active requests") - @field_validator("description") - @classmethod - def validate_description(cls, value: str | None) -> str | None: - if value is None: - return value - return validate_description_length(value) - class CopyAppPayload(BaseModel): name: str | None = Field(default=None, description="Name for the copied app") - description: str | None = Field(default=None, description="Description for the copied app") + description: str | None = Field(default=None, description="Description for the copied app", max_length=400) icon_type: str | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") - @field_validator("description") - @classmethod - def validate_description(cls, value: str | None) -> str | None: - if value is None: - return value - return validate_description_length(value) - class AppExportQuery(BaseModel): include_secret: bool = Field(default=False, description="Include secrets in export") @@ -146,7 +124,14 @@ class AppApiStatusPayload(BaseModel): class AppTracePayload(BaseModel): enabled: bool = Field(..., description="Enable or disable tracing") - tracing_provider: str = Field(..., description="Tracing provider") + tracing_provider: str | None = Field(default=None, description="Tracing provider") + + @field_validator("tracing_provider") + @classmethod + def validate_tracing_provider(cls, value: str | None, info) -> str | None: + if info.data.get("enabled") and not value: + raise ValueError("tracing_provider is required when enabled is True") + return value def reg(cls: type[BaseModel]): @@ -324,10 +309,13 @@ class AppListApi(Resource): NodeType.TRIGGER_PLUGIN, } for workflow in draft_workflows: - for _, node_data in workflow.walk_nodes(): - if node_data.get("type") in trigger_node_types: - draft_trigger_app_ids.add(str(workflow.app_id)) - break + try: + for _, node_data in workflow.walk_nodes(): + if node_data.get("type") in trigger_node_types: + draft_trigger_app_ids.add(str(workflow.app_id)) + break + except Exception: + continue for app in app_pagination.items: app.has_draft_trigger = str(app.id) in draft_trigger_app_ids diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 1b02edd489..22e2aeb720 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,4 +1,5 @@ -from flask_restx import Resource, fields, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field from sqlalchemy.orm import Session from controllers.console.app.wraps import get_app_model @@ -35,23 +36,29 @@ app_import_check_dependencies_model = console_ns.model( "AppImportCheckDependencies", app_import_check_dependencies_fields_copy ) -parser = ( - reqparse.RequestParser() - .add_argument("mode", type=str, required=True, location="json") - .add_argument("yaml_content", type=str, location="json") - .add_argument("yaml_url", type=str, location="json") - .add_argument("name", type=str, location="json") - .add_argument("description", type=str, location="json") - .add_argument("icon_type", type=str, location="json") - .add_argument("icon", type=str, location="json") - .add_argument("icon_background", type=str, location="json") - .add_argument("app_id", type=str, location="json") +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class AppImportPayload(BaseModel): + mode: str = Field(..., description="Import mode") + yaml_content: str | None = None + yaml_url: str | None = None + name: str | None = None + description: str | None = None + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + app_id: str | None = None + + +console_ns.schema_model( + AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) @console_ns.route("/apps/imports") class AppImportApi(Resource): - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[AppImportPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -61,7 +68,7 @@ class AppImportApi(Resource): def post(self): # Check user role first current_user, _ = current_account_with_tenant() - args = parser.parse_args() + args = AppImportPayload.model_validate(console_ns.payload) # Create service with session with Session(db.engine) as session: @@ -70,15 +77,15 @@ class AppImportApi(Resource): account = current_user result = import_service.import_app( account=account, - import_mode=args["mode"], - yaml_content=args.get("yaml_content"), - yaml_url=args.get("yaml_url"), - name=args.get("name"), - description=args.get("description"), - icon_type=args.get("icon_type"), - icon=args.get("icon"), - icon_background=args.get("icon_background"), - app_id=args.get("app_id"), + import_mode=args.mode, + yaml_content=args.yaml_content, + yaml_url=args.yaml_url, + name=args.name, + description=args.description, + icon_type=args.icon_type, + icon=args.icon, + icon_background=args.icon_background, + app_id=args.app_id, ) session.commit() if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 86446f1164..d344ede466 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource, fields +from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services @@ -32,6 +33,27 @@ from services.errors.audio import ( ) logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class TextToSpeechPayload(BaseModel): + message_id: str | None = Field(default=None, description="Message ID") + text: str = Field(..., description="Text to convert") + voice: str | None = Field(default=None, description="Voice name") + streaming: bool | None = Field(default=None, description="Whether to stream audio") + + +class TextToSpeechVoiceQuery(BaseModel): + language: str = Field(..., description="Language code") + + +console_ns.schema_model( + TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + TextToSpeechVoiceQuery.__name__, + TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) @console_ns.route("/apps//audio-to-text") @@ -92,17 +114,7 @@ class ChatMessageTextApi(Resource): @console_ns.doc("chat_message_text_to_speech") @console_ns.doc(description="Convert text to speech for chat messages") @console_ns.doc(params={"app_id": "App ID"}) - @console_ns.expect( - console_ns.model( - "TextToSpeechRequest", - { - "message_id": fields.String(description="Message ID"), - "text": fields.String(required=True, description="Text to convert to speech"), - "voice": fields.String(description="Voice to use for TTS"), - "streaming": fields.Boolean(description="Whether to stream the audio"), - }, - ) - ) + @console_ns.expect(console_ns.models[TextToSpeechPayload.__name__]) @console_ns.response(200, "Text to speech conversion successful") @console_ns.response(400, "Bad request - Invalid parameters") @get_app_model @@ -111,21 +123,14 @@ class ChatMessageTextApi(Resource): @account_initialization_required def post(self, app_model: App): try: - parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=str, location="json") - .add_argument("text", type=str, location="json") - .add_argument("voice", type=str, location="json") - .add_argument("streaming", type=bool, location="json") - ) - args = parser.parse_args() - - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + payload = TextToSpeechPayload.model_validate(console_ns.payload) response = AudioService.transcript_tts( - app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True + app_model=app_model, + text=payload.text, + voice=payload.voice, + message_id=payload.message_id, + is_draft=True, ) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -159,9 +164,7 @@ class TextModesApi(Resource): @console_ns.doc("get_text_to_speech_voices") @console_ns.doc(description="Get available TTS voices for a specific language") @console_ns.doc(params={"app_id": "App ID"}) - @console_ns.expect( - console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code") - ) + @console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__]) @console_ns.response( 200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")) ) @@ -172,12 +175,11 @@ class TextModesApi(Resource): @account_initialization_required def get(self, app_model): try: - parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args") - args = parser.parse_args() + args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore response = AudioService.transcript_tts_voices( tenant_id=app_model.tenant_id, - language=args["language"], + language=args.language, ) return response diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 9dcadc18a4..c16dcfd91f 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -49,7 +49,6 @@ class CompletionConversationQuery(BaseConversationQuery): class ChatConversationQuery(BaseConversationQuery): - message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count") sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( default="-updated_at", description="Sort field and direction" ) @@ -509,14 +508,6 @@ class ChatConversationApi(Resource): .having(func.count(MessageAnnotation.id) == 0) ) - if args.message_count_gte and args.message_count_gte >= 1: - query = ( - query.options(joinedload(Conversation.messages)) # type: ignore - .join(Message, Message.conversation_id == Conversation.id) - .group_by(Conversation.id) - .having(func.count(Message.id) >= args.message_count_gte) - ) - if app_model.mode == AppMode.ADVANCED_CHAT: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 58d1fb4a2d..dd982b6d7b 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -1,7 +1,8 @@ import json from enum import StrEnum -from flask_restx import Resource, fields, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -12,6 +13,8 @@ from fields.app_fields import app_server_fields from libs.login import current_account_with_tenant, login_required from models.model import AppMCPServer +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + # Register model for flask_restx to avoid dict type issues in Swagger app_server_model = console_ns.model("AppServer", app_server_fields) @@ -21,6 +24,22 @@ class AppMCPServerStatus(StrEnum): INACTIVE = "inactive" +class MCPServerCreatePayload(BaseModel): + description: str | None = Field(default=None, description="Server description") + parameters: dict = Field(..., description="Server parameters configuration") + + +class MCPServerUpdatePayload(BaseModel): + id: str = Field(..., description="Server ID") + description: str | None = Field(default=None, description="Server description") + parameters: dict = Field(..., description="Server parameters configuration") + status: str | None = Field(default=None, description="Server status") + + +for model in (MCPServerCreatePayload, MCPServerUpdatePayload): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + @console_ns.route("/apps//server") class AppMCPServerController(Resource): @console_ns.doc("get_app_mcp_server") @@ -39,15 +58,7 @@ class AppMCPServerController(Resource): @console_ns.doc("create_app_mcp_server") @console_ns.doc(description="Create MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "MCPServerCreateRequest", - { - "description": fields.String(description="Server description"), - "parameters": fields.Raw(required=True, description="Server parameters configuration"), - }, - ) - ) + @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__]) @console_ns.response(201, "MCP server configuration created successfully", app_server_model) @console_ns.response(403, "Insufficient permissions") @account_initialization_required @@ -58,21 +69,16 @@ class AppMCPServerController(Resource): @edit_permission_required def post(self, app_model): _, current_tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("description", type=str, required=False, location="json") - .add_argument("parameters", type=dict, required=True, location="json") - ) - args = parser.parse_args() + payload = MCPServerCreatePayload.model_validate(console_ns.payload or {}) - description = args.get("description") + description = payload.description if not description: description = app_model.description or "" server = AppMCPServer( name=app_model.name, description=description, - parameters=json.dumps(args["parameters"], ensure_ascii=False), + parameters=json.dumps(payload.parameters, ensure_ascii=False), status=AppMCPServerStatus.ACTIVE, app_id=app_model.id, tenant_id=current_tenant_id, @@ -85,17 +91,7 @@ class AppMCPServerController(Resource): @console_ns.doc("update_app_mcp_server") @console_ns.doc(description="Update MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "MCPServerUpdateRequest", - { - "id": fields.String(required=True, description="Server ID"), - "description": fields.String(description="Server description"), - "parameters": fields.Raw(required=True, description="Server parameters configuration"), - "status": fields.String(description="Server status"), - }, - ) - ) + @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__]) @console_ns.response(200, "MCP server configuration updated successfully", app_server_model) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @@ -106,19 +102,12 @@ class AppMCPServerController(Resource): @marshal_with(app_server_model) @edit_permission_required def put(self, app_model): - parser = ( - reqparse.RequestParser() - .add_argument("id", type=str, required=True, location="json") - .add_argument("description", type=str, required=False, location="json") - .add_argument("parameters", type=dict, required=True, location="json") - .add_argument("status", type=str, required=False, location="json") - ) - args = parser.parse_args() - server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first() + payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) + server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first() if not server: raise NotFound() - description = args.get("description") + description = payload.description if description is None: pass elif not description: @@ -126,11 +115,11 @@ class AppMCPServerController(Resource): else: server.description = description - server.parameters = json.dumps(args["parameters"], ensure_ascii=False) - if args["status"]: - if args["status"] not in [status.value for status in AppMCPServerStatus]: + server.parameters = json.dumps(payload.parameters, ensure_ascii=False) + if payload.status: + if payload.status not in [status.value for status in AppMCPServerStatus]: raise ValueError("Invalid status") - server.status = args["status"] + server.status = payload.status db.session.commit() return server diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 377297c84c..12ada8b798 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel): class MessageFeedbackPayload(BaseModel): message_id: str = Field(..., description="Message ID") rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") + content: str | None = Field(default=None, description="Feedback content") @field_validator("message_id") @classmethod @@ -324,6 +325,7 @@ class MessageFeedbackApi(Resource): db.session.delete(feedback) elif args.rating and feedback: feedback.rating = args.rating + feedback.content = args.content elif not args.rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") else: @@ -335,6 +337,7 @@ class MessageFeedbackApi(Resource): conversation_id=message.conversation_id, message_id=message.id, rating=rating_value, + content=args.content, from_source="admin", from_account_id=current_user.id, ) diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 19c1a11258..cbcf513162 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,4 +1,8 @@ -from flask_restx import Resource, fields, reqparse +from typing import Any + +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest from controllers.console import console_ns @@ -7,6 +11,26 @@ from controllers.console.wraps import account_initialization_required, setup_req from libs.login import login_required from services.ops_service import OpsService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class TraceProviderQuery(BaseModel): + tracing_provider: str = Field(..., description="Tracing provider name") + + +class TraceConfigPayload(BaseModel): + tracing_provider: str = Field(..., description="Tracing provider name") + tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data") + + +console_ns.schema_model( + TraceProviderQuery.__name__, + TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + @console_ns.route("/apps//trace-config") class TraceAppConfigApi(Resource): @@ -17,11 +41,7 @@ class TraceAppConfigApi(Resource): @console_ns.doc("get_trace_app_config") @console_ns.doc(description="Get tracing configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.parser().add_argument( - "tracing_provider", type=str, required=True, location="args", help="Tracing provider name" - ) - ) + @console_ns.expect(console_ns.models[TraceProviderQuery.__name__]) @console_ns.response( 200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data") ) @@ -30,11 +50,10 @@ class TraceAppConfigApi(Resource): @login_required @account_initialization_required def get(self, app_id): - parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args") - args = parser.parse_args() + args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore try: - trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) + trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) if not trace_config: return {"has_not_configured": True} return trace_config @@ -44,15 +63,7 @@ class TraceAppConfigApi(Resource): @console_ns.doc("create_trace_app_config") @console_ns.doc(description="Create a new tracing configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "TraceConfigCreateRequest", - { - "tracing_provider": fields.String(required=True, description="Tracing provider name"), - "tracing_config": fields.Raw(required=True, description="Tracing configuration data"), - }, - ) - ) + @console_ns.expect(console_ns.models[TraceConfigPayload.__name__]) @console_ns.response( 201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data") ) @@ -62,16 +73,11 @@ class TraceAppConfigApi(Resource): @account_initialization_required def post(self, app_id): """Create a new trace app configuration""" - parser = ( - reqparse.RequestParser() - .add_argument("tracing_provider", type=str, required=True, location="json") - .add_argument("tracing_config", type=dict, required=True, location="json") - ) - args = parser.parse_args() + args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.create_tracing_app_config( - app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] + app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigIsExist() @@ -84,15 +90,7 @@ class TraceAppConfigApi(Resource): @console_ns.doc("update_trace_app_config") @console_ns.doc(description="Update an existing tracing configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "TraceConfigUpdateRequest", - { - "tracing_provider": fields.String(required=True, description="Tracing provider name"), - "tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"), - }, - ) - ) + @console_ns.expect(console_ns.models[TraceConfigPayload.__name__]) @console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response")) @console_ns.response(400, "Invalid request parameters or configuration not found") @setup_required @@ -100,16 +98,11 @@ class TraceAppConfigApi(Resource): @account_initialization_required def patch(self, app_id): """Update an existing trace app configuration""" - parser = ( - reqparse.RequestParser() - .add_argument("tracing_provider", type=str, required=True, location="json") - .add_argument("tracing_config", type=dict, required=True, location="json") - ) - args = parser.parse_args() + args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.update_tracing_app_config( - app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] + app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigNotExist() @@ -120,11 +113,7 @@ class TraceAppConfigApi(Resource): @console_ns.doc("delete_trace_app_config") @console_ns.doc(description="Delete an existing tracing configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.parser().add_argument( - "tracing_provider", type=str, required=True, location="args", help="Tracing provider name" - ) - ) + @console_ns.expect(console_ns.models[TraceProviderQuery.__name__]) @console_ns.response(204, "Tracing configuration deleted successfully") @console_ns.response(400, "Invalid request parameters or configuration not found") @setup_required @@ -132,11 +121,10 @@ class TraceAppConfigApi(Resource): @account_initialization_required def delete(self, app_id): """Delete an existing trace app configuration""" - parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args") - args = parser.parse_args() + args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore try: - result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) + result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) if not result: raise TracingConfigNotExist() return {"result": "success"}, 204 diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index d46b8c5c9d..db218d8b81 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,4 +1,7 @@ -from flask_restx import Resource, fields, marshal_with, reqparse +from typing import Literal + +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import NotFound from constants.languages import supported_language @@ -16,69 +19,50 @@ from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import Site +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class AppSiteUpdatePayload(BaseModel): + title: str | None = Field(default=None) + icon_type: str | None = Field(default=None) + icon: str | None = Field(default=None) + icon_background: str | None = Field(default=None) + description: str | None = Field(default=None) + default_language: str | None = Field(default=None) + chat_color_theme: str | None = Field(default=None) + chat_color_theme_inverted: bool | None = Field(default=None) + customize_domain: str | None = Field(default=None) + copyright: str | None = Field(default=None) + privacy_policy: str | None = Field(default=None) + custom_disclaimer: str | None = Field(default=None) + customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None) + prompt_public: bool | None = Field(default=None) + show_workflow_steps: bool | None = Field(default=None) + use_icon_as_answer_icon: bool | None = Field(default=None) + + @field_validator("default_language") + @classmethod + def validate_language(cls, value: str | None) -> str | None: + if value is None: + return value + return supported_language(value) + + +console_ns.schema_model( + AppSiteUpdatePayload.__name__, + AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + # Register model for flask_restx to avoid dict type issues in Swagger app_site_model = console_ns.model("AppSite", app_site_fields) -def parse_app_site_args(): - parser = ( - reqparse.RequestParser() - .add_argument("title", type=str, required=False, location="json") - .add_argument("icon_type", type=str, required=False, location="json") - .add_argument("icon", type=str, required=False, location="json") - .add_argument("icon_background", type=str, required=False, location="json") - .add_argument("description", type=str, required=False, location="json") - .add_argument("default_language", type=supported_language, required=False, location="json") - .add_argument("chat_color_theme", type=str, required=False, location="json") - .add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") - .add_argument("customize_domain", type=str, required=False, location="json") - .add_argument("copyright", type=str, required=False, location="json") - .add_argument("privacy_policy", type=str, required=False, location="json") - .add_argument("custom_disclaimer", type=str, required=False, location="json") - .add_argument( - "customize_token_strategy", - type=str, - choices=["must", "allow", "not_allow"], - required=False, - location="json", - ) - .add_argument("prompt_public", type=bool, required=False, location="json") - .add_argument("show_workflow_steps", type=bool, required=False, location="json") - .add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json") - ) - return parser.parse_args() - - @console_ns.route("/apps//site") class AppSite(Resource): @console_ns.doc("update_app_site") @console_ns.doc(description="Update application site configuration") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "AppSiteRequest", - { - "title": fields.String(description="Site title"), - "icon_type": fields.String(description="Icon type"), - "icon": fields.String(description="Icon"), - "icon_background": fields.String(description="Icon background color"), - "description": fields.String(description="Site description"), - "default_language": fields.String(description="Default language"), - "chat_color_theme": fields.String(description="Chat color theme"), - "chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"), - "customize_domain": fields.String(description="Custom domain"), - "copyright": fields.String(description="Copyright text"), - "privacy_policy": fields.String(description="Privacy policy"), - "custom_disclaimer": fields.String(description="Custom disclaimer"), - "customize_token_strategy": fields.String( - enum=["must", "allow", "not_allow"], description="Token strategy" - ), - "prompt_public": fields.Boolean(description="Make prompt public"), - "show_workflow_steps": fields.Boolean(description="Show workflow steps"), - "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"), - }, - ) - ) + @console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__]) @console_ns.response(200, "Site configuration updated successfully", app_site_model) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "App not found") @@ -89,7 +73,7 @@ class AppSite(Resource): @get_app_model @marshal_with(app_site_model) def post(self, app_model): - args = parse_app_site_args() + args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: @@ -113,7 +97,7 @@ class AppSite(Resource): "show_workflow_steps", "use_icon_as_answer_icon", ]: - value = args.get(attr_name) + value = getattr(args, attr_name) if value is not None: setattr(site, attr_name, value) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 41ae8727de..3382b65acc 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -1,10 +1,11 @@ import logging from collections.abc import Callable from functools import wraps -from typing import NoReturn, ParamSpec, TypeVar +from typing import Any, NoReturn, ParamSpec, TypeVar -from flask import Response -from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask import Response, request +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel, Field from sqlalchemy.orm import Session from controllers.console import console_ns @@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList, from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowDraftVariableListQuery(BaseModel): + page: int = Field(default=1, ge=1, le=100_000, description="Page number") + limit: int = Field(default=20, ge=1, le=100, description="Items per page") + + +class WorkflowDraftVariableUpdatePayload(BaseModel): + name: str | None = Field(default=None, description="Variable name") + value: Any | None = Field(default=None, description="Variable value") + + +console_ns.schema_model( + WorkflowDraftVariableListQuery.__name__, + WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + WorkflowDraftVariableUpdatePayload.__name__, + WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) def _convert_values_to_json_serializable_object(value: Segment): @@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable): return _convert_values_to_json_serializable_object(value) -def _create_pagination_parser(): - parser = ( - reqparse.RequestParser() - .add_argument( - "page", - type=inputs.int_range(1, 100_000), - required=False, - default=1, - location="args", - help="the page of data requested", - ) - .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") - ) - return parser - - def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: value_type = workflow_draft_var.value_type return value_type.exposed_type().value @@ -201,7 +207,7 @@ def _api_prerequisite(f: Callable[P, R]): @console_ns.route("/apps//workflows/draft/variables") class WorkflowVariableCollectionApi(Resource): - @console_ns.expect(_create_pagination_parser()) + @console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__]) @console_ns.doc("get_workflow_variables") @console_ns.doc(description="Get draft workflow variables") @console_ns.doc(params={"app_id": "Application ID"}) @@ -215,8 +221,7 @@ class WorkflowVariableCollectionApi(Resource): """ Get draft workflow """ - parser = _create_pagination_parser() - args = parser.parse_args() + args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore # fetch draft workflow by app_model workflow_service = WorkflowService() @@ -323,15 +328,7 @@ class VariableApi(Resource): @console_ns.doc("update_variable") @console_ns.doc(description="Update a workflow variable") - @console_ns.expect( - console_ns.model( - "UpdateVariableRequest", - { - "name": fields.String(description="Variable name"), - "value": fields.Raw(description="Variable value"), - }, - ) - ) + @console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__]) @console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model) @console_ns.response(404, "Variable not found") @_api_prerequisite @@ -358,16 +355,10 @@ class VariableApi(Resource): # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # } - parser = ( - reqparse.RequestParser() - .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") - .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") - ) - draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) - args = parser.parse_args(strict=True) + args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {}) variable = draft_var_srv.get_variable(variable_id=variable_id) if variable is None: @@ -375,8 +366,8 @@ class VariableApi(Resource): if variable.app_id != app_model.id: raise NotFoundError(description=f"variable not found, id={variable_id}") - new_name = args.get(self._PATCH_NAME_FIELD, None) - raw_value = args.get(self._PATCH_VALUE_FIELD, None) + new_name = args_model.name + raw_value = args_model.value if new_name is None and raw_value is None: return variable diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index 5d16e4f979..9433b732e4 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -114,7 +114,7 @@ class AppTriggersApi(Resource): @console_ns.route("/apps//trigger-enable") class AppTriggerEnableApi(Resource): - @console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True) + @console_ns.expect(console_ns.models[ParserEnable.__name__]) @setup_required @login_required @account_initialization_required diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index a11b741040..6834656a7f 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,28 +1,53 @@ from flask import request -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from constants.languages import supported_language from controllers.console import console_ns from controllers.console.error import AlreadyActivateError from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from libs.helper import StrLen, email, extract_remote_ip, timezone +from libs.helper import EmailStr, extract_remote_ip, timezone from models import AccountStatus from services.account_service import AccountService, RegisterService -active_check_parser = ( - reqparse.RequestParser() - .add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID") - .add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address") - .add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token") -) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class ActivateCheckQuery(BaseModel): + workspace_id: str | None = Field(default=None) + email: EmailStr | None = Field(default=None) + token: str + + +class ActivatePayload(BaseModel): + workspace_id: str | None = Field(default=None) + email: EmailStr | None = Field(default=None) + token: str + name: str = Field(..., max_length=30) + interface_language: str = Field(...) + timezone: str = Field(...) + + @field_validator("interface_language") + @classmethod + def validate_lang(cls, value: str) -> str: + return supported_language(value) + + @field_validator("timezone") + @classmethod + def validate_tz(cls, value: str) -> str: + return timezone(value) + + +for model in (ActivateCheckQuery, ActivatePayload): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) @console_ns.route("/activate/check") class ActivateCheckApi(Resource): @console_ns.doc("check_activation_token") @console_ns.doc(description="Check if activation token is valid") - @console_ns.expect(active_check_parser) + @console_ns.expect(console_ns.models[ActivateCheckQuery.__name__]) @console_ns.response( 200, "Success", @@ -35,11 +60,11 @@ class ActivateCheckApi(Resource): ), ) def get(self): - args = active_check_parser.parse_args() + args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - workspaceId = args["workspace_id"] - reg_email = args["email"] - token = args["token"] + workspaceId = args.workspace_id + reg_email = args.email + token = args.token invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) if invitation: @@ -56,22 +81,11 @@ class ActivateCheckApi(Resource): return {"is_valid": False} -active_parser = ( - reqparse.RequestParser() - .add_argument("workspace_id", type=str, required=False, nullable=True, location="json") - .add_argument("email", type=email, required=False, nullable=True, location="json") - .add_argument("token", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") - .add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json") - .add_argument("timezone", type=timezone, required=True, nullable=False, location="json") -) - - @console_ns.route("/activate") class ActivateApi(Resource): @console_ns.doc("activate_account") @console_ns.doc(description="Activate account with invitation token") - @console_ns.expect(active_parser) + @console_ns.expect(console_ns.models[ActivatePayload.__name__]) @console_ns.response( 200, "Account activated successfully", @@ -85,19 +99,19 @@ class ActivateApi(Resource): ) @console_ns.response(400, "Already activated or invalid token") def post(self): - args = active_parser.parse_args() + args = ActivatePayload.model_validate(console_ns.payload) - invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) + invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token) if invitation is None: raise AlreadyActivateError() - RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) + RegisterService.revoke_token(args.workspace_id, args.email, args.token) account = invitation["account"] - account.name = args["name"] + account.name = args.name - account.interface_language = args["interface_language"] - account.timezone = args["timezone"] + account.interface_language = args.interface_language + account.timezone = args.timezone account.interface_theme = "light" account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 9d7fcef183..905d0daef0 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,12 +1,26 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field -from controllers.console import console_ns -from controllers.console.auth.error import ApiKeyAuthFailedError -from controllers.console.wraps import is_admin_or_owner_required from libs.login import current_account_with_tenant, login_required from services.auth.api_key_auth_service import ApiKeyAuthService -from ..wraps import account_initialization_required, setup_required +from .. import console_ns +from ..auth.error import ApiKeyAuthFailedError +from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required + +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class ApiKeyAuthBindingPayload(BaseModel): + category: str = Field(...) + provider: str = Field(...) + credentials: dict = Field(...) + + +console_ns.schema_model( + ApiKeyAuthBindingPayload.__name__, + ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) @console_ns.route("/api-key-auth/data-source") @@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource): @login_required @account_initialization_required @is_admin_or_owner_required + @console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__]) def post(self): # The role of the current user in the table must be admin or owner _, current_tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("category", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - ) - args = parser.parse_args() - ApiKeyAuthService.validate_api_key_auth_args(args) + payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload) + data = payload.model_dump() + ApiKeyAuthService.validate_api_key_auth_args(data) try: - ApiKeyAuthService.create_provider_auth(current_tenant_id, args) + ApiKeyAuthService.create_provider_auth(current_tenant_id, data) except Exception as e: raise ApiKeyAuthFailedError(str(e)) return {"result": "success"}, 200 diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index cd547caf20..0dd7d33ae9 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -5,12 +5,11 @@ from flask import current_app, redirect, request from flask_restx import Resource, fields from configs import dify_config -from controllers.console import console_ns -from controllers.console.wraps import is_admin_or_owner_required from libs.login import login_required from libs.oauth_data_source import NotionOAuth -from ..wraps import account_initialization_required, setup_required +from .. import console_ns +from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required logger = logging.getLogger(__name__) diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index fe2bb54e0b..fa082c735d 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,5 +1,6 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -14,16 +15,45 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) -from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError -from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required from extensions.ext_database import db -from libs.helper import email, extract_remote_ip +from libs.helper import EmailStr, extract_remote_ip from libs.password import valid_password from models import Account from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import AccountNotFoundError, AccountRegisterError +from ..error import AccountInFreezeError, EmailSendIpLimitError +from ..wraps import email_password_login_enabled, email_register_enabled, setup_required + +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class EmailRegisterSendPayload(BaseModel): + email: EmailStr = Field(..., description="Email address") + language: str | None = Field(default=None, description="Language code") + + +class EmailRegisterValidityPayload(BaseModel): + email: EmailStr = Field(...) + code: str = Field(...) + token: str = Field(...) + + +class EmailRegisterResetPayload(BaseModel): + token: str = Field(...) + new_password: str = Field(...) + password_confirm: str = Field(...) + + @field_validator("new_password", "password_confirm") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) + + +for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + @console_ns.route("/email-register/send-email") class EmailRegisterSendEmailApi(Resource): @@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource): @email_password_login_enabled @email_register_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = EmailRegisterSendPayload.model_validate(console_ns.payload) ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() language = "en-US" - if args["language"] in languages: - language = args["language"] + if args.language in languages: + language = args.language - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email): raise AccountInFreezeError() with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() token = None - token = AccountService.send_email_register_email(email=args["email"], account=account, language=language) + token = AccountService.send_email_register_email(email=args.email, account=account, language=language) return {"result": "success", "data": token} @@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource): @email_password_login_enabled @email_register_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + args = EmailRegisterValidityPayload.model_validate(console_ns.payload) - user_email = args["email"] + user_email = args.email - is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"]) + is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email) if is_email_register_error_rate_limit: raise EmailRegisterLimitError() - token_data = AccountService.get_email_register_data(args["token"]) + token_data = AccountService.get_email_register_data(args.token) if token_data is None: raise InvalidTokenError() if user_email != token_data.get("email"): raise InvalidEmailError() - if args["code"] != token_data.get("code"): - AccountService.add_email_register_error_rate_limit(args["email"]) + if args.code != token_data.get("code"): + AccountService.add_email_register_error_rate_limit(args.email) raise EmailCodeError() # Verified, revoke the first token - AccountService.revoke_email_register_token(args["token"]) + AccountService.revoke_email_register_token(args.token) # Refresh token data by generating a new token _, new_token = AccountService.generate_email_register_token( - user_email, code=args["code"], additional_data={"phase": "register"} + user_email, code=args.code, additional_data={"phase": "register"} ) - AccountService.reset_email_register_error_rate_limit(args["email"]) + AccountService.reset_email_register_error_rate_limit(args.email) return {"is_valid": True, "email": token_data.get("email"), "token": new_token} @@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource): @email_password_login_enabled @email_register_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("token", type=str, required=True, nullable=False, location="json") - .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") - .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + args = EmailRegisterResetPayload.model_validate(console_ns.payload) # Validate passwords match - if args["new_password"] != args["password_confirm"]: + if args.new_password != args.password_confirm: raise PasswordMismatchError() # Validate token and get register data - register_data = AccountService.get_email_register_data(args["token"]) + register_data = AccountService.get_email_register_data(args.token) if not register_data: raise InvalidTokenError() # Must use token in reset phase @@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource): raise InvalidTokenError() # Revoke token to prevent reuse - AccountService.revoke_email_register_token(args["token"]) + AccountService.revoke_email_register_token(args.token) email = register_data.get("email", "") @@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource): if account: raise EmailAlreadyInUseError() else: - account = self._create_new_account(email, args["password_confirm"]) + account = self._create_new_account(email, args.password_confirm) if not account: raise AccountNotFoundError() token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index ee561bdd30..661f591182 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,7 +2,8 @@ import base64 import secrets from flask import request -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db -from libs.helper import email, extract_remote_ip +from libs.helper import EmailStr, extract_remote_ip from libs.password import hash_password, valid_password from models import Account from services.account_service import AccountService, TenantService from services.feature_service import FeatureService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class ForgotPasswordSendPayload(BaseModel): + email: EmailStr = Field(...) + language: str | None = Field(default=None) + + +class ForgotPasswordCheckPayload(BaseModel): + email: EmailStr = Field(...) + code: str = Field(...) + token: str = Field(...) + + +class ForgotPasswordResetPayload(BaseModel): + token: str = Field(...) + new_password: str = Field(...) + password_confirm: str = Field(...) + + @field_validator("new_password", "password_confirm") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) + + +for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + @console_ns.route("/forgot-password") class ForgotPasswordSendEmailApi(Resource): @console_ns.doc("send_forgot_password_email") @console_ns.doc(description="Send password reset email") - @console_ns.expect( - console_ns.model( - "ForgotPasswordEmailRequest", - { - "email": fields.String(required=True, description="Email address"), - "language": fields.String(description="Language for email (zh-Hans/en-US)"), - }, - ) - ) + @console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__]) @console_ns.response( 200, "Email sent successfully", @@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource): @setup_required @email_password_login_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = ForgotPasswordSendPayload.model_validate(console_ns.payload) ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - if args["language"] is not None and args["language"] == "zh-Hans": + if args.language is not None and args.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() token = AccountService.send_reset_password_email( account=account, - email=args["email"], + email=args.email, language=language, is_allow_register=FeatureService.get_system_features().is_allow_register, ) @@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordCheckApi(Resource): @console_ns.doc("check_forgot_password_code") @console_ns.doc(description="Verify password reset code") - @console_ns.expect( - console_ns.model( - "ForgotPasswordCheckRequest", - { - "email": fields.String(required=True, description="Email address"), - "code": fields.String(required=True, description="Verification code"), - "token": fields.String(required=True, description="Reset token"), - }, - ) - ) + @console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__]) @console_ns.response( 200, "Code verified successfully", @@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource): @setup_required @email_password_login_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + args = ForgotPasswordCheckPayload.model_validate(console_ns.payload) - user_email = args["email"] + user_email = args.email - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() - token_data = AccountService.get_reset_password_data(args["token"]) + token_data = AccountService.get_reset_password_data(args.token) if token_data is None: raise InvalidTokenError() if user_email != token_data.get("email"): raise InvalidEmailError() - if args["code"] != token_data.get("code"): - AccountService.add_forgot_password_error_rate_limit(args["email"]) + if args.code != token_data.get("code"): + AccountService.add_forgot_password_error_rate_limit(args.email) raise EmailCodeError() # Verified, revoke the first token - AccountService.revoke_reset_password_token(args["token"]) + AccountService.revoke_reset_password_token(args.token) # Refresh token data by generating a new token _, new_token = AccountService.generate_reset_password_token( - user_email, code=args["code"], additional_data={"phase": "reset"} + user_email, code=args.code, additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(args["email"]) + AccountService.reset_forgot_password_error_rate_limit(args.email) return {"is_valid": True, "email": token_data.get("email"), "token": new_token} @@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource): class ForgotPasswordResetApi(Resource): @console_ns.doc("reset_password") @console_ns.doc(description="Reset password with verification token") - @console_ns.expect( - console_ns.model( - "ForgotPasswordResetRequest", - { - "token": fields.String(required=True, description="Verification token"), - "new_password": fields.String(required=True, description="New password"), - "password_confirm": fields.String(required=True, description="Password confirmation"), - }, - ) - ) + @console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__]) @console_ns.response( 200, "Password reset successfully", @@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource): @setup_required @email_password_login_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("token", type=str, required=True, nullable=False, location="json") - .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") - .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + args = ForgotPasswordResetPayload.model_validate(console_ns.payload) # Validate passwords match - if args["new_password"] != args["password_confirm"]: + if args.new_password != args.password_confirm: raise PasswordMismatchError() # Validate token and get reset data - reset_data = AccountService.get_reset_password_data(args["token"]) + reset_data = AccountService.get_reset_password_data(args.token) if not reset_data: raise InvalidTokenError() # Must use token in reset phase @@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource): raise InvalidTokenError() # Revoke token to prevent reuse - AccountService.revoke_reset_password_token(args["token"]) + AccountService.revoke_reset_password_token(args.token) # Generate secure salt and hash password salt = secrets.token_bytes(16) - password_hashed = hash_password(args["new_password"], salt) + password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 77ecd5a5e4..f486f4c313 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,6 +1,7 @@ import flask_login from flask import make_response, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field import services from configs import dify_config @@ -23,7 +24,7 @@ from controllers.console.error import ( ) from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created -from libs.helper import email, extract_remote_ip +from libs.helper import EmailStr, extract_remote_ip from libs.login import current_account_with_tenant from libs.token import ( clear_access_token_from_cookie, @@ -40,6 +41,36 @@ from services.errors.account import AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class LoginPayload(BaseModel): + email: EmailStr = Field(..., description="Email address") + password: str = Field(..., description="Password") + remember_me: bool = Field(default=False, description="Remember me flag") + invite_token: str | None = Field(default=None, description="Invitation token") + + +class EmailPayload(BaseModel): + email: EmailStr = Field(...) + language: str | None = Field(default=None) + + +class EmailCodeLoginPayload(BaseModel): + email: EmailStr = Field(...) + code: str = Field(...) + token: str = Field(...) + language: str | None = Field(default=None) + + +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(LoginPayload) +reg(EmailPayload) +reg(EmailCodeLoginPayload) + @console_ns.route("/login") class LoginApi(Resource): @@ -47,41 +78,36 @@ class LoginApi(Resource): @setup_required @email_password_login_enabled + @console_ns.expect(console_ns.models[LoginPayload.__name__]) def post(self): """Authenticate user and login.""" - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("password", type=str, required=True, location="json") - .add_argument("remember_me", type=bool, required=False, default=False, location="json") - .add_argument("invite_token", type=str, required=False, default=None, location="json") - ) - args = parser.parse_args() + args = LoginPayload.model_validate(console_ns.payload) - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email): raise AccountInFreezeError() - is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"]) + is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email) if is_login_error_rate_limit: raise EmailPasswordLoginLimitError() - invitation = args["invite_token"] + # TODO: why invitation is re-assigned with different type? + invitation = args.invite_token # type: ignore if invitation: - invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation) + invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore try: if invitation: - data = invitation.get("data", {}) + data = invitation.get("data", {}) # type: ignore invitee_email = data.get("email") if data else None - if invitee_email != args["email"]: + if invitee_email != args.email: raise InvalidEmailError() - account = AccountService.authenticate(args["email"], args["password"], args["invite_token"]) + account = AccountService.authenticate(args.email, args.password, args.invite_token) else: - account = AccountService.authenticate(args["email"], args["password"]) + account = AccountService.authenticate(args.email, args.password) except services.errors.account.AccountLoginError: raise AccountBannedError() except services.errors.account.AccountPasswordError: - AccountService.add_login_error_rate_limit(args["email"]) + AccountService.add_login_error_rate_limit(args.email) raise AuthenticationFailedError() # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) @@ -97,7 +123,7 @@ class LoginApi(Resource): } token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(args["email"]) + AccountService.reset_login_error_rate_limit(args.email) # Create response with cookies instead of returning tokens in body response = make_response({"result": "success"}) @@ -134,25 +160,21 @@ class LogoutApi(Resource): class ResetPasswordSendEmailApi(Resource): @setup_required @email_password_login_enabled + @console_ns.expect(console_ns.models[EmailPayload.__name__]) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = EmailPayload.model_validate(console_ns.payload) - if args["language"] is not None and args["language"] == "zh-Hans": + if args.language is not None and args.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" try: - account = AccountService.get_user_through_email(args["email"]) + account = AccountService.get_user_through_email(args.email) except AccountRegisterError: raise AccountInFreezeError() token = AccountService.send_reset_password_email( - email=args["email"], + email=args.email, account=account, language=language, is_allow_register=FeatureService.get_system_features().is_allow_register, @@ -164,30 +186,26 @@ class ResetPasswordSendEmailApi(Resource): @console_ns.route("/email-code-login") class EmailCodeLoginSendEmailApi(Resource): @setup_required + @console_ns.expect(console_ns.models[EmailPayload.__name__]) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = EmailPayload.model_validate(console_ns.payload) ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - if args["language"] is not None and args["language"] == "zh-Hans": + if args.language is not None and args.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" try: - account = AccountService.get_user_through_email(args["email"]) + account = AccountService.get_user_through_email(args.email) except AccountRegisterError: raise AccountInFreezeError() if account is None: if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_email_code_login_email(email=args["email"], language=language) + token = AccountService.send_email_code_login_email(email=args.email, language=language) else: raise AccountNotFound() else: @@ -199,30 +217,24 @@ class EmailCodeLoginSendEmailApi(Resource): @console_ns.route("/email-code-login/validity") class EmailCodeLoginApi(Resource): @setup_required + @console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__]) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = EmailCodeLoginPayload.model_validate(console_ns.payload) - user_email = args["email"] - language = args["language"] + user_email = args.email + language = args.language - token_data = AccountService.get_email_code_login_data(args["token"]) + token_data = AccountService.get_email_code_login_data(args.token) if token_data is None: raise InvalidTokenError() - if token_data["email"] != args["email"]: + if token_data["email"] != args.email: raise InvalidEmailError() - if token_data["code"] != args["code"]: + if token_data["code"] != args.code: raise EmailCodeError() - AccountService.revoke_email_code_login_token(args["token"]) + AccountService.revoke_email_code_login_token(args.token) try: account = AccountService.get_user_through_email(user_email) except AccountRegisterError: @@ -255,7 +267,7 @@ class EmailCodeLoginApi(Resource): except WorkspacesLimitExceededError: raise WorkspacesLimitExceeded() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(args["email"]) + AccountService.reset_login_error_rate_limit(args.email) # Create response with cookies instead of returning tokens in body response = make_response({"result": "success"}) diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 5e12aa7d03..6162d88a0b 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -3,7 +3,8 @@ from functools import wraps from typing import Concatenate, ParamSpec, TypeVar from flask import jsonify, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required @@ -20,15 +21,34 @@ R = TypeVar("R") T = TypeVar("T") +class OAuthClientPayload(BaseModel): + client_id: str + + +class OAuthProviderRequest(BaseModel): + client_id: str + redirect_uri: str + + +class OAuthTokenRequest(BaseModel): + client_id: str + grant_type: str + code: str | None = None + client_secret: str | None = None + redirect_uri: str | None = None + refresh_token: str | None = None + + def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): @wraps(view) def decorated(self: T, *args: P.args, **kwargs: P.kwargs): - parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json") - parsed_args = parser.parse_args() - client_id = parsed_args.get("client_id") - if not client_id: + json_data = request.get_json() + if json_data is None: raise BadRequest("client_id is required") + payload = OAuthClientPayload.model_validate(json_data) + client_id = payload.client_id + oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id) if not oauth_provider_app: raise NotFound("client_id is invalid") @@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource): @setup_required @oauth_server_client_id_required def post(self, oauth_provider_app: OAuthProviderApp): - parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json") - parsed_args = parser.parse_args() - redirect_uri = parsed_args.get("redirect_uri") + payload = OAuthProviderRequest.model_validate(request.get_json()) + redirect_uri = payload.redirect_uri # check if redirect_uri is valid if redirect_uri not in oauth_provider_app.redirect_uris: @@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource): @setup_required @oauth_server_client_id_required def post(self, oauth_provider_app: OAuthProviderApp): - parser = ( - reqparse.RequestParser() - .add_argument("grant_type", type=str, required=True, location="json") - .add_argument("code", type=str, required=False, location="json") - .add_argument("client_secret", type=str, required=False, location="json") - .add_argument("redirect_uri", type=str, required=False, location="json") - .add_argument("refresh_token", type=str, required=False, location="json") - ) - parsed_args = parser.parse_args() + payload = OAuthTokenRequest.model_validate(request.get_json()) try: - grant_type = OAuthGrantType(parsed_args["grant_type"]) + grant_type = OAuthGrantType(payload.grant_type) except ValueError: raise BadRequest("invalid grant_type") if grant_type == OAuthGrantType.AUTHORIZATION_CODE: - if not parsed_args["code"]: + if not payload.code: raise BadRequest("code is required") - if parsed_args["client_secret"] != oauth_provider_app.client_secret: + if payload.client_secret != oauth_provider_app.client_secret: raise BadRequest("client_secret is invalid") - if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris: + if payload.redirect_uri not in oauth_provider_app.redirect_uris: raise BadRequest("redirect_uri is invalid") access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id + grant_type, code=payload.code, client_id=oauth_provider_app.client_id ) return jsonable_encoder( { @@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource): } ) elif grant_type == OAuthGrantType.REFRESH_TOKEN: - if not parsed_args["refresh_token"]: + if not payload.refresh_token: raise BadRequest("refresh_token is required") access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id + grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id ) return jsonable_encoder( { diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 4fef1ba40d..7f907dc420 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,6 +1,8 @@ import base64 -from flask_restx import Resource, fields, reqparse +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest from controllers.console import console_ns @@ -9,6 +11,35 @@ from enums.cloud_plan import CloudPlan from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class SubscriptionQuery(BaseModel): + plan: str = Field(..., description="Subscription plan") + interval: str = Field(..., description="Billing interval") + + @field_validator("plan") + @classmethod + def validate_plan(cls, value: str) -> str: + if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]: + raise ValueError("Invalid plan") + return value + + @field_validator("interval") + @classmethod + def validate_interval(cls, value: str) -> str: + if value not in {"month", "year"}: + raise ValueError("Invalid interval") + return value + + +class PartnerTenantsPayload(BaseModel): + click_id: str = Field(..., description="Click Id from partner referral link") + + +for model in (SubscriptionQuery, PartnerTenantsPayload): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + @console_ns.route("/billing/subscription") class Subscription(Resource): @@ -18,20 +49,9 @@ class Subscription(Resource): @only_edition_cloud def get(self): current_user, current_tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument( - "plan", - type=str, - required=True, - location="args", - choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM], - ) - .add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) - ) - args = parser.parse_args() + args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore BillingService.is_tenant_owner_or_admin(current_user) - return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id) + return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id) @console_ns.route("/billing/invoices") @@ -65,11 +85,10 @@ class PartnerTenants(Resource): @only_edition_cloud def put(self, partner_key: str): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json") - args = parser.parse_args() try: - click_id = args["click_id"] + args = PartnerTenantsPayload.model_validate(console_ns.payload or {}) + click_id = args.click_id decoded_partner_key = base64.b64decode(partner_key).decode("utf-8") except Exception: raise BadRequest("Invalid partner_key") diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py index 2a6889968c..afc5f92b68 100644 --- a/api/controllers/console/billing/compliance.py +++ b/api/controllers/console/billing/compliance.py @@ -1,5 +1,6 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from libs.helper import extract_remote_ip from libs.login import current_account_with_tenant, login_required @@ -9,16 +10,28 @@ from .. import console_ns from ..wraps import account_initialization_required, only_edition_cloud, setup_required +class ComplianceDownloadQuery(BaseModel): + doc_name: str = Field(..., description="Compliance document name") + + +console_ns.schema_model( + ComplianceDownloadQuery.__name__, + ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"), +) + + @console_ns.route("/compliance/download") class ComplianceApi(Resource): + @console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__]) + @console_ns.doc("download_compliance_document") + @console_ns.doc(description="Get compliance document download link") @setup_required @login_required @account_initialization_required @only_edition_cloud def get(self): current_user, current_tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args") - args = parser.parse_args() + args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore ip_address = extract_remote_ip(request) device_info = request.headers.get("User-Agent", "Unknown device") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index ef66053075..01f268d94d 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,15 +1,15 @@ import json from collections.abc import Generator -from typing import cast +from typing import Any, cast from flask import request -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.common.schema import register_schema_model from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.indexing_runner import IndexingRunner @@ -25,6 +25,19 @@ from services.dataset_service import DatasetService, DocumentService from services.datasource_provider_service import DatasourceProviderService from tasks.document_indexing_sync_task import document_indexing_sync_task +from .. import console_ns +from ..wraps import account_initialization_required, setup_required + + +class NotionEstimatePayload(BaseModel): + notion_info_list: list[dict[str, Any]] + process_rule: dict[str, Any] + doc_form: str = Field(default="text_model") + doc_language: str = Field(default="English") + + +register_schema_model(console_ns, NotionEstimatePayload) + @console_ns.route( "/data-source/integrates", @@ -243,20 +256,15 @@ class DataSourceNotionApi(Resource): @setup_required @login_required @account_initialization_required + @console_ns.expect(console_ns.models[NotionEstimatePayload.__name__]) def post(self): _, current_tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") - .add_argument("process_rule", type=dict, required=True, nullable=True, location="json") - .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") - ) - args = parser.parse_args() + payload = NotionEstimatePayload.model_validate(console_ns.payload or {}) + args = payload.model_dump() # validate args DocumentService.estimate_args_validate(args) - notion_info_list = args["notion_info_list"] + notion_info_list = payload.notion_info_list extract_settings = [] for notion_info in notion_info_list: workspace_id = notion_info["workspace_id"] diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 45bc1fa694..70b6e932e9 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,12 +1,14 @@ from typing import Any, cast from flask import request -from flask_restx import Resource, fields, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound import services from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.apikey import ( api_key_item_model, @@ -48,7 +50,6 @@ from fields.dataset_fields import ( ) from fields.document_fields import document_status_fields from libs.login import current_account_with_tenant, login_required -from libs.validators import validate_description_length from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermissionEnum from models.provider_ids import ModelProviderID @@ -107,10 +108,75 @@ related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_mode related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy) -def _validate_name(name: str) -> str: - if not name or len(name) < 1 or len(name) > 40: - raise ValueError("Name must be between 1 to 40 characters.") - return name +def _validate_indexing_technique(value: str | None) -> str | None: + if value is None: + return value + if value not in Dataset.INDEXING_TECHNIQUE_LIST: + raise ValueError("Invalid indexing technique.") + return value + + +class DatasetCreatePayload(BaseModel): + name: str = Field(..., min_length=1, max_length=40) + description: str = Field("", max_length=400) + indexing_technique: str | None = None + permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME + provider: str = "vendor" + external_knowledge_api_id: str | None = None + external_knowledge_id: str | None = None + + @field_validator("indexing_technique") + @classmethod + def validate_indexing(cls, value: str | None) -> str | None: + return _validate_indexing_technique(value) + + @field_validator("provider") + @classmethod + def validate_provider(cls, value: str) -> str: + if value not in Dataset.PROVIDER_LIST: + raise ValueError("Invalid provider.") + return value + + +class DatasetUpdatePayload(BaseModel): + name: str | None = Field(None, min_length=1, max_length=40) + description: str | None = Field(None, max_length=400) + permission: DatasetPermissionEnum | None = None + indexing_technique: str | None = None + embedding_model: str | None = None + embedding_model_provider: str | None = None + retrieval_model: dict[str, Any] | None = None + partial_member_list: list[str] | None = None + external_retrieval_model: dict[str, Any] | None = None + external_knowledge_id: str | None = None + external_knowledge_api_id: str | None = None + icon_info: dict[str, Any] | None = None + is_multimodal: bool | None = False + + @field_validator("indexing_technique") + @classmethod + def validate_indexing(cls, value: str | None) -> str | None: + return _validate_indexing_technique(value) + + +class IndexingEstimatePayload(BaseModel): + info_list: dict[str, Any] + process_rule: dict[str, Any] + indexing_technique: str + doc_form: str = "text_model" + dataset_id: str | None = None + doc_language: str = "English" + + @field_validator("indexing_technique") + @classmethod + def validate_indexing(cls, value: str) -> str: + result = _validate_indexing_technique(value) + if result is None: + raise ValueError("indexing_technique is required.") + return result + + +register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload) def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]: @@ -255,20 +321,7 @@ class DatasetListApi(Resource): @console_ns.doc("create_dataset") @console_ns.doc(description="Create a new dataset") - @console_ns.expect( - console_ns.model( - "CreateDatasetRequest", - { - "name": fields.String(required=True, description="Dataset name (1-40 characters)"), - "description": fields.String(description="Dataset description (max 400 characters)"), - "indexing_technique": fields.String(description="Indexing technique"), - "permission": fields.String(description="Dataset permission"), - "provider": fields.String(description="Provider"), - "external_knowledge_api_id": fields.String(description="External knowledge API ID"), - "external_knowledge_id": fields.String(description="External knowledge ID"), - }, - ) - ) + @console_ns.expect(console_ns.models[DatasetCreatePayload.__name__]) @console_ns.response(201, "Dataset created successfully") @console_ns.response(400, "Invalid request parameters") @setup_required @@ -276,52 +329,7 @@ class DatasetListApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") def post(self): - parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, - ) - .add_argument( - "description", - type=validate_description_length, - nullable=True, - required=False, - default="", - ) - .add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help="Invalid indexing technique.", - ) - .add_argument( - "external_knowledge_api_id", - type=str, - nullable=True, - required=False, - ) - .add_argument( - "provider", - type=str, - nullable=True, - choices=Dataset.PROVIDER_LIST, - required=False, - default="vendor", - ) - .add_argument( - "external_knowledge_id", - type=str, - nullable=True, - required=False, - ) - ) - args = parser.parse_args() + payload = DatasetCreatePayload.model_validate(console_ns.payload or {}) current_user, current_tenant_id = current_account_with_tenant() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator @@ -331,14 +339,14 @@ class DatasetListApi(Resource): try: dataset = DatasetService.create_empty_dataset( tenant_id=current_tenant_id, - name=args["name"], - description=args["description"], - indexing_technique=args["indexing_technique"], + name=payload.name, + description=payload.description, + indexing_technique=payload.indexing_technique, account=current_user, - permission=DatasetPermissionEnum.ONLY_ME, - provider=args["provider"], - external_knowledge_api_id=args["external_knowledge_api_id"], - external_knowledge_id=args["external_knowledge_id"], + permission=payload.permission or DatasetPermissionEnum.ONLY_ME, + provider=payload.provider, + external_knowledge_api_id=payload.external_knowledge_api_id, + external_knowledge_id=payload.external_knowledge_id, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -399,18 +407,7 @@ class DatasetApi(Resource): @console_ns.doc("update_dataset") @console_ns.doc(description="Update dataset details") - @console_ns.expect( - console_ns.model( - "UpdateDatasetRequest", - { - "name": fields.String(description="Dataset name"), - "description": fields.String(description="Dataset description"), - "permission": fields.String(description="Dataset permission"), - "indexing_technique": fields.String(description="Indexing technique"), - "external_retrieval_model": fields.Raw(description="External retrieval model settings"), - }, - ) - ) + @console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__]) @console_ns.response(200, "Dataset updated successfully", dataset_detail_model) @console_ns.response(404, "Dataset not found") @console_ns.response(403, "Permission denied") @@ -424,93 +421,25 @@ class DatasetApi(Resource): if dataset is None: raise NotFound("Dataset not found.") - parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, - ) - .add_argument("description", location="json", store_missing=False, type=validate_description_length) - .add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help="Invalid indexing technique.", - ) - .add_argument( - "permission", - type=str, - location="json", - choices=( - DatasetPermissionEnum.ONLY_ME, - DatasetPermissionEnum.ALL_TEAM, - DatasetPermissionEnum.PARTIAL_TEAM, - ), - help="Invalid permission.", - ) - .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") - .add_argument( - "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." - ) - .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") - .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") - .add_argument( - "external_retrieval_model", - type=dict, - required=False, - nullable=True, - location="json", - help="Invalid external retrieval model.", - ) - .add_argument( - "external_knowledge_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge id.", - ) - .add_argument( - "external_knowledge_api_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge api id.", - ) - .add_argument( - "icon_info", - type=dict, - required=False, - nullable=True, - location="json", - help="Invalid icon info.", - ) - ) - args = parser.parse_args() - data = request.get_json() + payload = DatasetUpdatePayload.model_validate(console_ns.payload or {}) current_user, current_tenant_id = current_account_with_tenant() - # check embedding model setting if ( - data.get("indexing_technique") == "high_quality" - and data.get("embedding_model_provider") is not None - and data.get("embedding_model") is not None + payload.indexing_technique == "high_quality" + and payload.embedding_model_provider is not None + and payload.embedding_model is not None ): - DatasetService.check_embedding_model_setting( - dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") + is_multimodal = DatasetService.check_is_multimodal_model( + dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model ) - + payload.is_multimodal = is_multimodal + payload_data = payload.model_dump(exclude_unset=True) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( - current_user, dataset, data.get("permission"), data.get("partial_member_list") + current_user, dataset, payload.permission, payload.partial_member_list ) - dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) + dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user) if dataset is None: raise NotFound("Dataset not found.") @@ -518,15 +447,10 @@ class DatasetApi(Resource): result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) tenant_id = current_tenant_id - if data.get("partial_member_list") and data.get("permission") == "partial_members": - DatasetPermissionService.update_partial_member_list( - tenant_id, dataset_id_str, data.get("partial_member_list") - ) + if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM: + DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list) # clear partial member list when permission is only_me or all_team_members - elif ( - data.get("permission") == DatasetPermissionEnum.ONLY_ME - or data.get("permission") == DatasetPermissionEnum.ALL_TEAM - ): + elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}: DatasetPermissionService.clear_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) @@ -615,24 +539,10 @@ class DatasetIndexingEstimateApi(Resource): @setup_required @login_required @account_initialization_required + @console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__]) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("info_list", type=dict, required=True, nullable=True, location="json") - .add_argument("process_rule", type=dict, required=True, nullable=True, location="json") - .add_argument( - "indexing_technique", - type=str, - required=True, - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - location="json", - ) - .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - .add_argument("dataset_id", type=str, required=False, nullable=False, location="json") - .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") - ) - args = parser.parse_args() + payload = IndexingEstimatePayload.model_validate(console_ns.payload or {}) + args = payload.model_dump() _, current_tenant_id = current_account_with_tenant() # validate args DocumentService.estimate_args_validate(args) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 2663c939bc..6145da31a5 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -6,31 +6,14 @@ from typing import Literal, cast import sqlalchemy as sa from flask import request -from flask_restx import Resource, fields, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns -from controllers.console.app.error import ( - ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, - ProviderQuotaExceededError, -) -from controllers.console.datasets.error import ( - ArchivedDocumentImmutableError, - DocumentAlreadyFinishedError, - DocumentIndexingError, - IndexingEstimateError, - InvalidActionError, - InvalidMetadataError, -) -from controllers.console.wraps import ( - account_initialization_required, - cloud_edition_billing_rate_limit_check, - cloud_edition_billing_resource_check, - setup_required, -) from core.errors.error import ( LLMBadRequestError, ModelCurrentlyNotSupportError, @@ -55,10 +38,30 @@ from fields.document_fields import ( ) from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required -from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile +from models import DatasetProcessRule, Document, DocumentSegment, UploadFile from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService -from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig +from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel + +from ..app.error import ( + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from ..datasets.error import ( + ArchivedDocumentImmutableError, + DocumentAlreadyFinishedError, + DocumentIndexingError, + IndexingEstimateError, + InvalidActionError, + InvalidMetadataError, +) +from ..wraps import ( + account_initialization_required, + cloud_edition_billing_rate_limit_check, + cloud_edition_billing_resource_check, + setup_required, +) logger = logging.getLogger(__name__) @@ -93,6 +96,24 @@ dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(docume dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) +class DocumentRetryPayload(BaseModel): + document_ids: list[str] + + +class DocumentRenamePayload(BaseModel): + name: str + + +register_schema_models( + console_ns, + KnowledgeConfig, + ProcessRule, + RetrievalModel, + DocumentRetryPayload, + DocumentRenamePayload, +) + + class DocumentResource(Resource): def get_document(self, dataset_id: str, document_id: str) -> Document: current_user, current_tenant_id = current_account_with_tenant() @@ -201,8 +222,9 @@ class DatasetDocumentListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: str): + def get(self, dataset_id): current_user, current_tenant_id = current_account_with_tenant() + dataset_id = str(dataset_id) page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) search = request.args.get("keyword", default=None, type=str) @@ -310,6 +332,7 @@ class DatasetDocumentListApi(Resource): @marshal_with(dataset_and_document_model) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) def post(self, dataset_id): current_user, _ = current_account_with_tenant() dataset_id = str(dataset_id) @@ -328,23 +351,7 @@ class DatasetDocumentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - parser = ( - reqparse.RequestParser() - .add_argument( - "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" - ) - .add_argument("data_source", type=dict, required=False, location="json") - .add_argument("process_rule", type=dict, required=False, location="json") - .add_argument("duplicate", type=bool, default=True, nullable=False, location="json") - .add_argument("original_document_id", type=str, required=False, location="json") - .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") - .add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") - ) - args = parser.parse_args() - knowledge_config = KnowledgeConfig.model_validate(args) + knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {}) if not dataset.indexing_technique and not knowledge_config.indexing_technique: raise ValueError("indexing_technique is required.") @@ -390,17 +397,7 @@ class DatasetDocumentListApi(Resource): class DatasetInitApi(Resource): @console_ns.doc("init_dataset") @console_ns.doc(description="Initialize dataset with documents") - @console_ns.expect( - console_ns.model( - "DatasetInitRequest", - { - "upload_file_id": fields.String(required=True, description="Upload file ID"), - "indexing_technique": fields.String(description="Indexing technique"), - "process_rule": fields.Raw(description="Processing rules"), - "data_source": fields.Raw(description="Data source configuration"), - }, - ) - ) + @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model) @console_ns.response(400, "Invalid request parameters") @setup_required @@ -415,27 +412,7 @@ class DatasetInitApi(Resource): if not current_user.is_dataset_editor: raise Forbidden() - parser = ( - reqparse.RequestParser() - .add_argument( - "indexing_technique", - type=str, - choices=Dataset.INDEXING_TECHNIQUE_LIST, - required=True, - nullable=False, - location="json", - ) - .add_argument("data_source", type=dict, required=True, nullable=True, location="json") - .add_argument("process_rule", type=dict, required=True, nullable=True, location="json") - .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") - .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") - .add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - ) - args = parser.parse_args() - - knowledge_config = KnowledgeConfig.model_validate(args) + knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {}) if knowledge_config.indexing_technique == "high_quality": if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") @@ -443,10 +420,14 @@ class DatasetInitApi(Resource): model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_tenant_id, - provider=args["embedding_model_provider"], + provider=knowledge_config.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=args["embedding_model"], + model=knowledge_config.embedding_model, ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model + ) + knowledge_config.is_multimodal = is_multimodal except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." @@ -1076,19 +1057,16 @@ class DocumentRetryApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[DocumentRetryPayload.__name__]) def post(self, dataset_id): """retry document.""" - - parser = reqparse.RequestParser().add_argument( - "document_ids", type=list, required=True, nullable=False, location="json" - ) - args = parser.parse_args() + payload = DocumentRetryPayload.model_validate(console_ns.payload or {}) dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) retry_documents = [] if not dataset: raise NotFound("Dataset not found.") - for document_id in args["document_ids"]: + for document_id in payload.document_ids: try: document_id = str(document_id) @@ -1121,6 +1099,7 @@ class DocumentRenameApi(DocumentResource): @login_required @account_initialization_required @marshal_with(document_fields) + @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) def post(self, dataset_id, document_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator current_user, _ = current_account_with_tenant() @@ -1130,11 +1109,10 @@ class DocumentRenameApi(DocumentResource): if not dataset: raise NotFound("Dataset not found.") DatasetService.check_dataset_operator_permission(current_user, dataset) - parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() + payload = DocumentRenamePayload.model_validate(console_ns.payload or {}) try: - document = DocumentService.rename_document(dataset_id, document_id, args["name"]) + document = DocumentService.rename_document(dataset_id, document_id, payload.name) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 2fe7d42e46..e73abc2555 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,11 +1,13 @@ import uuid from flask import request -from flask_restx import Resource, marshal, reqparse +from flask_restx import Resource, marshal +from pydantic import BaseModel, Field from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import ( @@ -36,6 +38,58 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task +class SegmentListQuery(BaseModel): + limit: int = Field(default=20, ge=1, le=100) + status: list[str] = Field(default_factory=list) + hit_count_gte: int | None = None + enabled: str = Field(default="all") + keyword: str | None = None + page: int = Field(default=1, ge=1) + + +class SegmentCreatePayload(BaseModel): + content: str + answer: str | None = None + keywords: list[str] | None = None + attachment_ids: list[str] | None = None + + +class SegmentUpdatePayload(BaseModel): + content: str + answer: str | None = None + keywords: list[str] | None = None + regenerate_child_chunks: bool = False + attachment_ids: list[str] | None = None + + +class BatchImportPayload(BaseModel): + upload_file_id: str + + +class ChildChunkCreatePayload(BaseModel): + content: str + + +class ChildChunkUpdatePayload(BaseModel): + content: str + + +class ChildChunkBatchUpdatePayload(BaseModel): + chunks: list[ChildChunkUpdateArgs] + + +register_schema_models( + console_ns, + SegmentListQuery, + SegmentCreatePayload, + SegmentUpdatePayload, + BatchImportPayload, + ChildChunkCreatePayload, + ChildChunkUpdatePayload, + ChildChunkBatchUpdatePayload, +) + + @console_ns.route("/datasets//documents//segments") class DatasetDocumentSegmentListApi(Resource): @setup_required @@ -60,23 +114,18 @@ class DatasetDocumentSegmentListApi(Resource): if not document: raise NotFound("Document not found.") - parser = ( - reqparse.RequestParser() - .add_argument("limit", type=int, default=20, location="args") - .add_argument("status", type=str, action="append", default=[], location="args") - .add_argument("hit_count_gte", type=int, default=None, location="args") - .add_argument("enabled", type=str, default="all", location="args") - .add_argument("keyword", type=str, default=None, location="args") - .add_argument("page", type=int, default=1, location="args") + args = SegmentListQuery.model_validate( + { + **request.args.to_dict(), + "status": request.args.getlist("status"), + } ) - args = parser.parse_args() - - page = args["page"] - limit = min(args["limit"], 100) - status_list = args["status"] - hit_count_gte = args["hit_count_gte"] - keyword = args["keyword"] + page = args.page + limit = min(args.limit, 100) + status_list = args.status + hit_count_gte = args.hit_count_gte + keyword = args.keyword query = ( select(DocumentSegment) @@ -96,10 +145,10 @@ class DatasetDocumentSegmentListApi(Resource): if keyword: query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) - if args["enabled"].lower() != "all": - if args["enabled"].lower() == "true": + if args.enabled.lower() != "all": + if args.enabled.lower() == "true": query = query.where(DocumentSegment.enabled == True) - elif args["enabled"].lower() == "false": + elif args.enabled.lower() == "false": query = query.where(DocumentSegment.enabled == False) segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @@ -210,6 +259,7 @@ class DatasetDocumentSegmentAddApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[SegmentCreatePayload.__name__]) def post(self, dataset_id, document_id): current_user, current_tenant_id = current_account_with_tenant() @@ -246,15 +296,10 @@ class DatasetDocumentSegmentAddApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) # validate args - parser = ( - reqparse.RequestParser() - .add_argument("content", type=str, required=True, nullable=False, location="json") - .add_argument("answer", type=str, required=False, nullable=True, location="json") - .add_argument("keywords", type=list, required=False, nullable=True, location="json") - ) - args = parser.parse_args() - SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.create_segment(args, document, dataset) + payload = SegmentCreatePayload.model_validate(console_ns.payload or {}) + payload_dict = payload.model_dump(exclude_none=True) + SegmentService.segment_create_args_validate(payload_dict, document) + segment = SegmentService.create_segment(payload_dict, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @@ -265,6 +310,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__]) def patch(self, dataset_id, document_id, segment_id): current_user, current_tenant_id = current_account_with_tenant() @@ -313,18 +359,12 @@ class DatasetDocumentSegmentUpdateApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) # validate args - parser = ( - reqparse.RequestParser() - .add_argument("content", type=str, required=True, nullable=False, location="json") - .add_argument("answer", type=str, required=False, nullable=True, location="json") - .add_argument("keywords", type=list, required=False, nullable=True, location="json") - .add_argument( - "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json" - ) + payload = SegmentUpdatePayload.model_validate(console_ns.payload or {}) + payload_dict = payload.model_dump(exclude_none=True) + SegmentService.segment_create_args_validate(payload_dict, document) + segment = SegmentService.update_segment( + SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset ) - args = parser.parse_args() - SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @@ -377,6 +417,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[BatchImportPayload.__name__]) def post(self, dataset_id, document_id): current_user, current_tenant_id = current_account_with_tenant() @@ -391,11 +432,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource): if not document: raise NotFound("Document not found.") - parser = reqparse.RequestParser().add_argument( - "upload_file_id", type=str, required=True, nullable=False, location="json" - ) - args = parser.parse_args() - upload_file_id = args["upload_file_id"] + payload = BatchImportPayload.model_validate(console_ns.payload or {}) + upload_file_id = payload.upload_file_id upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() if not upload_file: @@ -446,6 +484,7 @@ class ChildChunkAddApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__]) def post(self, dataset_id, document_id, segment_id): current_user, current_tenant_id = current_account_with_tenant() @@ -491,13 +530,9 @@ class ChildChunkAddApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) # validate args - parser = reqparse.RequestParser().add_argument( - "content", type=str, required=True, nullable=False, location="json" - ) - args = parser.parse_args() try: - content = args["content"] - child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset) + payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {}) + child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 @@ -529,18 +564,17 @@ class ChildChunkAddApi(Resource): ) if not segment: raise NotFound("Segment not found.") - parser = ( - reqparse.RequestParser() - .add_argument("limit", type=int, default=20, location="args") - .add_argument("keyword", type=str, default=None, location="args") - .add_argument("page", type=int, default=1, location="args") + args = SegmentListQuery.model_validate( + { + "limit": request.args.get("limit", default=20, type=int), + "keyword": request.args.get("keyword"), + "page": request.args.get("page", default=1, type=int), + } ) - args = parser.parse_args() - - page = args["page"] - limit = min(args["limit"], 100) - keyword = args["keyword"] + page = args.page + limit = min(args.limit, 100) + keyword = args.keyword child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) return { @@ -588,14 +622,9 @@ class ChildChunkAddApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) # validate args - parser = reqparse.RequestParser().add_argument( - "chunks", type=list, required=True, nullable=False, location="json" - ) - args = parser.parse_args() + payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {}) try: - chunks_data = args["chunks"] - chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data] - child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) + child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunks, child_chunk_fields)}, 200 @@ -665,6 +694,7 @@ class ChildChunkUpdateApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") + @console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__]) def patch(self, dataset_id, document_id, segment_id, child_chunk_id): current_user, current_tenant_id = current_account_with_tenant() @@ -711,13 +741,9 @@ class ChildChunkUpdateApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) # validate args - parser = reqparse.RequestParser().add_argument( - "content", type=str, required=True, nullable=False, location="json" - ) - args = parser.parse_args() try: - content = args["content"] - child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset) + payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {}) + child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 950884e496..89c9fcad36 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,8 +1,10 @@ from flask import request -from flask_restx import Resource, fields, marshal, reqparse +from flask_restx import Resource, fields, marshal +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required @@ -71,10 +73,38 @@ except KeyError: dataset_detail_model = _build_dataset_detail_model() -def _validate_name(name: str) -> str: - if not name or len(name) < 1 or len(name) > 100: - raise ValueError("Name must be between 1 to 100 characters.") - return name +class ExternalKnowledgeApiPayload(BaseModel): + name: str = Field(..., min_length=1, max_length=40) + settings: dict[str, object] + + +class ExternalDatasetCreatePayload(BaseModel): + external_knowledge_api_id: str + external_knowledge_id: str + name: str = Field(..., min_length=1, max_length=40) + description: str | None = Field(None, max_length=400) + external_retrieval_model: dict[str, object] | None = None + + +class ExternalHitTestingPayload(BaseModel): + query: str + external_retrieval_model: dict[str, object] | None = None + metadata_filtering_conditions: dict[str, object] | None = None + + +class BedrockRetrievalPayload(BaseModel): + retrieval_setting: dict[str, object] + query: str + knowledge_id: str + + +register_schema_models( + console_ns, + ExternalKnowledgeApiPayload, + ExternalDatasetCreatePayload, + ExternalHitTestingPayload, + BedrockRetrievalPayload, +) @console_ns.route("/datasets/external-knowledge-api") @@ -113,28 +143,12 @@ class ExternalApiTemplateListApi(Resource): @setup_required @login_required @account_initialization_required + @console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__]) def post(self): current_user, current_tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name is required. Name must be between 1 to 100 characters.", - type=_validate_name, - ) - .add_argument( - "settings", - type=dict, - location="json", - nullable=False, - required=True, - ) - ) - args = parser.parse_args() + payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {}) - ExternalDatasetService.validate_api_list(args["settings"]) + ExternalDatasetService.validate_api_list(payload.settings) # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: @@ -142,7 +156,7 @@ class ExternalApiTemplateListApi(Resource): try: external_knowledge_api = ExternalDatasetService.create_external_knowledge_api( - tenant_id=current_tenant_id, user_id=current_user.id, args=args + tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump() ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -171,35 +185,19 @@ class ExternalApiTemplateApi(Resource): @setup_required @login_required @account_initialization_required + @console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__]) def patch(self, external_knowledge_api_id): current_user, current_tenant_id = current_account_with_tenant() external_knowledge_api_id = str(external_knowledge_api_id) - parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="type is required. Name must be between 1 to 100 characters.", - type=_validate_name, - ) - .add_argument( - "settings", - type=dict, - location="json", - nullable=False, - required=True, - ) - ) - args = parser.parse_args() - ExternalDatasetService.validate_api_list(args["settings"]) + payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {}) + ExternalDatasetService.validate_api_list(payload.settings) external_knowledge_api = ExternalDatasetService.update_external_knowledge_api( tenant_id=current_tenant_id, user_id=current_user.id, external_knowledge_api_id=external_knowledge_api_id, - args=args, + args=payload.model_dump(), ) return external_knowledge_api.to_dict(), 200 @@ -240,17 +238,7 @@ class ExternalApiUseCheckApi(Resource): class ExternalDatasetCreateApi(Resource): @console_ns.doc("create_external_dataset") @console_ns.doc(description="Create external knowledge dataset") - @console_ns.expect( - console_ns.model( - "CreateExternalDatasetRequest", - { - "external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"), - "external_knowledge_id": fields.String(required=True, description="External knowledge ID"), - "name": fields.String(required=True, description="Dataset name"), - "description": fields.String(description="Dataset description"), - }, - ) - ) + @console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__]) @console_ns.response(201, "External dataset created successfully", dataset_detail_model) @console_ns.response(400, "Invalid parameters") @console_ns.response(403, "Permission denied") @@ -261,22 +249,8 @@ class ExternalDatasetCreateApi(Resource): def post(self): # The role of the current user in the ta table must be admin, owner, or editor current_user, current_tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json") - .add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json") - .add_argument( - "name", - nullable=False, - required=True, - help="name is required. Name must be between 1 to 100 characters.", - type=_validate_name, - ) - .add_argument("description", type=str, required=False, nullable=True, location="json") - .add_argument("external_retrieval_model", type=dict, required=False, location="json") - ) - - args = parser.parse_args() + payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: @@ -299,16 +273,7 @@ class ExternalKnowledgeHitTestingApi(Resource): @console_ns.doc("test_external_knowledge_retrieval") @console_ns.doc(description="Test external knowledge retrieval for dataset") @console_ns.doc(params={"dataset_id": "Dataset ID"}) - @console_ns.expect( - console_ns.model( - "ExternalHitTestingRequest", - { - "query": fields.String(required=True, description="Query text for testing"), - "retrieval_model": fields.Raw(description="Retrieval model configuration"), - "external_retrieval_model": fields.Raw(description="External retrieval model configuration"), - }, - ) - ) + @console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__]) @console_ns.response(200, "External hit testing completed successfully") @console_ns.response(404, "Dataset not found") @console_ns.response(400, "Invalid parameters") @@ -327,23 +292,16 @@ class ExternalKnowledgeHitTestingApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - parser = ( - reqparse.RequestParser() - .add_argument("query", type=str, location="json") - .add_argument("external_retrieval_model", type=dict, required=False, location="json") - .add_argument("metadata_filtering_conditions", type=dict, required=False, location="json") - ) - args = parser.parse_args() - - HitTestingService.hit_testing_args_check(args) + payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {}) + HitTestingService.hit_testing_args_check(payload.model_dump()) try: response = HitTestingService.external_retrieve( dataset=dataset, - query=args["query"], + query=payload.query, account=current_user, - external_retrieval_model=args["external_retrieval_model"], - metadata_filtering_conditions=args["metadata_filtering_conditions"], + external_retrieval_model=payload.external_retrieval_model, + metadata_filtering_conditions=payload.metadata_filtering_conditions, ) return response @@ -356,33 +314,13 @@ class BedrockRetrievalApi(Resource): # this api is only for internal testing @console_ns.doc("bedrock_retrieval_test") @console_ns.doc(description="Bedrock retrieval test (internal use only)") - @console_ns.expect( - console_ns.model( - "BedrockRetrievalTestRequest", - { - "retrieval_setting": fields.Raw(required=True, description="Retrieval settings"), - "query": fields.String(required=True, description="Query text"), - "knowledge_id": fields.String(required=True, description="Knowledge ID"), - }, - ) - ) + @console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__]) @console_ns.response(200, "Bedrock retrieval test completed") def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json") - .add_argument( - "query", - nullable=False, - required=True, - type=str, - ) - .add_argument("knowledge_id", nullable=False, required=True, type=str) - ) - args = parser.parse_args() + payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {}) # Call the knowledge retrieval service result = ExternalDatasetTestService.knowledge_retrieval( - args["retrieval_setting"], args["query"], args["knowledge_id"] + payload.retrieval_setting, payload.query, payload.knowledge_id ) return result, 200 diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 7ba2eeb7dd..932cb4fcce 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,13 +1,17 @@ -from flask_restx import Resource, fields +from flask_restx import Resource -from controllers.console import console_ns -from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase -from controllers.console.wraps import ( +from controllers.common.schema import register_schema_model +from libs.login import login_required + +from .. import console_ns +from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload +from ..wraps import ( account_initialization_required, cloud_edition_billing_rate_limit_check, setup_required, ) -from libs.login import login_required + +register_schema_model(console_ns, HitTestingPayload) @console_ns.route("/datasets//hit-testing") @@ -15,17 +19,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): @console_ns.doc("test_dataset_retrieval") @console_ns.doc(description="Test dataset knowledge retrieval") @console_ns.doc(params={"dataset_id": "Dataset ID"}) - @console_ns.expect( - console_ns.model( - "HitTestingRequest", - { - "query": fields.String(required=True, description="Query text for testing"), - "retrieval_model": fields.Raw(description="Retrieval model configuration"), - "top_k": fields.Integer(description="Number of top results to return"), - "score_threshold": fields.Float(description="Score threshold for filtering results"), - }, - ) - ) + @console_ns.expect(console_ns.models[HitTestingPayload.__name__]) @console_ns.response(200, "Hit testing completed successfully") @console_ns.response(404, "Dataset not found") @console_ns.response(400, "Invalid parameters") @@ -37,7 +31,8 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) - args = self.parse_args() + payload = HitTestingPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 99d4d5a29c..db7c50f422 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,6 +1,8 @@ import logging +from typing import Any from flask_restx import marshal, reqparse +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -27,6 +29,13 @@ from services.hit_testing_service import HitTestingService logger = logging.getLogger(__name__) +class HitTestingPayload(BaseModel): + query: str = Field(max_length=250) + retrieval_model: dict[str, Any] | None = None + external_retrieval_model: dict[str, Any] | None = None + attachment_ids: list[str] | None = None + + class DatasetsHitTestingBase: @staticmethod def get_and_validate_dataset(dataset_id: str): @@ -43,14 +52,15 @@ class DatasetsHitTestingBase: return dataset @staticmethod - def hit_testing_args_check(args): + def hit_testing_args_check(args: dict[str, Any]): HitTestingService.hit_testing_args_check(args) @staticmethod def parse_args(): parser = ( reqparse.RequestParser() - .add_argument("query", type=str, location="json") + .add_argument("query", type=str, required=False, location="json") + .add_argument("attachment_ids", type=list, required=False, location="json") .add_argument("retrieval_model", type=dict, required=False, location="json") .add_argument("external_retrieval_model", type=dict, required=False, location="json") ) @@ -62,10 +72,11 @@ class DatasetsHitTestingBase: try: response = HitTestingService.retrieve( dataset=dataset, - query=args["query"], + query=args.get("query"), account=current_user, - retrieval_model=args["retrieval_model"], - external_retrieval_model=args["external_retrieval_model"], + retrieval_model=args.get("retrieval_model"), + external_retrieval_model=args.get("external_retrieval_model"), + attachment_ids=args.get("attachment_ids"), limit=10, ) return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 72b2ff0ff8..8eead1696a 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -1,8 +1,10 @@ from typing import Literal -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_model, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from fields.dataset_fields import dataset_metadata_fields @@ -15,6 +17,14 @@ from services.entities.knowledge_entities.knowledge_entities import ( from services.metadata_service import MetadataService +class MetadataUpdatePayload(BaseModel): + name: str + + +register_schema_models(console_ns, MetadataArgs, MetadataOperationData) +register_schema_model(console_ns, MetadataUpdatePayload) + + @console_ns.route("/datasets//metadata") class DatasetMetadataCreateApi(Resource): @setup_required @@ -22,15 +32,10 @@ class DatasetMetadataCreateApi(Resource): @account_initialization_required @enterprise_license_required @marshal_with(dataset_metadata_fields) + @console_ns.expect(console_ns.models[MetadataArgs.__name__]) def post(self, dataset_id): current_user, _ = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - ) - args = parser.parse_args() - metadata_args = MetadataArgs.model_validate(args) + metadata_args = MetadataArgs.model_validate(console_ns.payload or {}) dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -60,11 +65,11 @@ class DatasetMetadataApi(Resource): @account_initialization_required @enterprise_license_required @marshal_with(dataset_metadata_fields) + @console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__]) def patch(self, dataset_id, metadata_id): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() - name = args["name"] + payload = MetadataUpdatePayload.model_validate(console_ns.payload or {}) + name = payload.name dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -131,6 +136,7 @@ class DocumentMetadataEditApi(Resource): @login_required @account_initialization_required @enterprise_license_required + @console_ns.expect(console_ns.models[MetadataOperationData.__name__]) def post(self, dataset_id): current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) @@ -139,11 +145,7 @@ class DocumentMetadataEditApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - parser = reqparse.RequestParser().add_argument( - "operation_data", type=list, required=True, nullable=False, location="json" - ) - args = parser.parse_args() - metadata_args = MetadataOperationData.model_validate(args) + metadata_args = MetadataOperationData.model_validate(console_ns.payload or {}) MetadataService.update_documents_metadata(dataset, metadata_args) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index cf9e5d2990..1a47e226e5 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -1,20 +1,63 @@ +from typing import Any + from flask import make_response, redirect, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler -from libs.helper import StrLen from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService +class DatasourceCredentialPayload(BaseModel): + name: str | None = Field(default=None, max_length=100) + credentials: dict[str, Any] + + +class DatasourceCredentialDeletePayload(BaseModel): + credential_id: str + + +class DatasourceCredentialUpdatePayload(BaseModel): + credential_id: str + name: str | None = Field(default=None, max_length=100) + credentials: dict[str, Any] | None = None + + +class DatasourceCustomClientPayload(BaseModel): + client_params: dict[str, Any] | None = None + enable_oauth_custom_client: bool | None = None + + +class DatasourceDefaultPayload(BaseModel): + id: str + + +class DatasourceUpdateNamePayload(BaseModel): + credential_id: str + name: str = Field(max_length=100) + + +register_schema_models( + console_ns, + DatasourceCredentialPayload, + DatasourceCredentialDeletePayload, + DatasourceCredentialUpdatePayload, + DatasourceCustomClientPayload, + DatasourceDefaultPayload, + DatasourceUpdateNamePayload, +) + + @console_ns.route("/oauth/plugin//datasource/get-authorization-url") class DatasourcePluginOAuthAuthorizationUrl(Resource): @setup_required @@ -121,16 +164,9 @@ class DatasourceOAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -parser_datasource = ( - reqparse.RequestParser() - .add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None) - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") -) - - @console_ns.route("/auth/plugin/datasource/") class DatasourceAuth(Resource): - @console_ns.expect(parser_datasource) + @console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -138,7 +174,7 @@ class DatasourceAuth(Resource): def post(self, provider_id: str): _, current_tenant_id = current_account_with_tenant() - args = parser_datasource.parse_args() + payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() @@ -146,8 +182,8 @@ class DatasourceAuth(Resource): datasource_provider_service.add_datasource_api_key_provider( tenant_id=current_tenant_id, provider_id=datasource_provider_id, - credentials=args["credentials"], - name=args["name"], + credentials=payload.credentials, + name=payload.name, ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) @@ -169,14 +205,9 @@ class DatasourceAuth(Resource): return {"result": datasources}, 200 -parser_datasource_delete = reqparse.RequestParser().add_argument( - "credential_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/auth/plugin/datasource//delete") class DatasourceAuthDeleteApi(Resource): - @console_ns.expect(parser_datasource_delete) + @console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -188,28 +219,20 @@ class DatasourceAuthDeleteApi(Resource): plugin_id = datasource_provider_id.plugin_id provider_name = datasource_provider_id.provider_name - args = parser_datasource_delete.parse_args() + payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {}) datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_datasource_credentials( tenant_id=current_tenant_id, - auth_id=args["credential_id"], + auth_id=payload.credential_id, provider=provider_name, plugin_id=plugin_id, ) return {"result": "success"}, 200 -parser_datasource_update = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") - .add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json") - .add_argument("credential_id", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/auth/plugin/datasource//update") class DatasourceAuthUpdateApi(Resource): - @console_ns.expect(parser_datasource_update) + @console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -218,16 +241,16 @@ class DatasourceAuthUpdateApi(Resource): _, current_tenant_id = current_account_with_tenant() datasource_provider_id = DatasourceProviderID(provider_id) - args = parser_datasource_update.parse_args() + payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {}) datasource_provider_service = DatasourceProviderService() datasource_provider_service.update_datasource_credentials( tenant_id=current_tenant_id, - auth_id=args["credential_id"], + auth_id=payload.credential_id, provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id, - credentials=args.get("credentials", {}), - name=args.get("name", None), + credentials=payload.credentials or {}, + name=payload.name, ) return {"result": "success"}, 201 @@ -258,16 +281,9 @@ class DatasourceHardCodeAuthListApi(Resource): return {"result": jsonable_encoder(datasources)}, 200 -parser_datasource_custom = ( - reqparse.RequestParser() - .add_argument("client_params", type=dict, required=False, nullable=True, location="json") - .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") -) - - @console_ns.route("/auth/plugin/datasource//custom-client") class DatasourceAuthOauthCustomClient(Resource): - @console_ns.expect(parser_datasource_custom) + @console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -275,14 +291,14 @@ class DatasourceAuthOauthCustomClient(Resource): def post(self, provider_id: str): _, current_tenant_id = current_account_with_tenant() - args = parser_datasource_custom.parse_args() + payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.setup_oauth_custom_client_params( tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, - client_params=args.get("client_params", {}), - enabled=args.get("enable_oauth_custom_client", False), + client_params=payload.client_params or {}, + enabled=payload.enable_oauth_custom_client or False, ) return {"result": "success"}, 200 @@ -301,12 +317,9 @@ class DatasourceAuthOauthCustomClient(Resource): return {"result": "success"}, 200 -parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json") - - @console_ns.route("/auth/plugin/datasource//default") class DatasourceAuthDefaultApi(Resource): - @console_ns.expect(parser_default) + @console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -314,27 +327,20 @@ class DatasourceAuthDefaultApi(Resource): def post(self, provider_id: str): _, current_tenant_id = current_account_with_tenant() - args = parser_default.parse_args() + payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.set_default_datasource_provider( tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, - credential_id=args["id"], + credential_id=payload.id, ) return {"result": "success"}, 200 -parser_update_name = ( - reqparse.RequestParser() - .add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json") - .add_argument("credential_id", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/auth/plugin/datasource//update-name") class DatasourceUpdateProviderNameApi(Resource): - @console_ns.expect(parser_update_name) + @console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -342,13 +348,13 @@ class DatasourceUpdateProviderNameApi(Resource): def post(self, provider_id: str): _, current_tenant_id = current_account_with_tenant() - args = parser_update_name.parse_args() + payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.update_datasource_provider_name( tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, - name=args["name"], - credential_id=args["credential_id"], + name=payload.name, + credential_id=payload.credential_id, ) return {"result": "success"}, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index 42387557d6..7caf5b52ed 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//preview") class DataSourceContentPreviewApi(Resource): - @console_ns.expect(console_ns.models[Parser.__name__], validate=True) + @console_ns.expect(console_ns.models[Parser.__name__]) @setup_required @login_required @account_initialization_required diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index f589bba3bf..6e0cd31b8d 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -1,9 +1,11 @@ import logging from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -20,18 +22,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService logger = logging.getLogger(__name__) -def _validate_name(name: str) -> str: - if not name or len(name) < 1 or len(name) > 40: - raise ValueError("Name must be between 1 to 40 characters.") - return name - - -def _validate_description_length(description: str) -> str: - if len(description) > 400: - raise ValueError("Description cannot exceed 400 characters.") - return description - - @console_ns.route("/rag/pipeline/templates") class PipelineTemplateListApi(Resource): @setup_required @@ -59,6 +49,15 @@ class PipelineTemplateDetailApi(Resource): return pipeline_template, 200 +class Payload(BaseModel): + name: str = Field(..., min_length=1, max_length=40) + description: str = Field(default="", max_length=400) + icon_info: dict[str, object] | None = None + + +register_schema_models(console_ns, Payload) + + @console_ns.route("/rag/pipeline/customized/templates/") class CustomizedPipelineTemplateApi(Resource): @setup_required @@ -66,31 +65,8 @@ class CustomizedPipelineTemplateApi(Resource): @account_initialization_required @enterprise_license_required def patch(self, template_id: str): - parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 40 characters.", - type=_validate_name, - ) - .add_argument( - "description", - type=_validate_description_length, - nullable=True, - required=False, - default="", - ) - .add_argument( - "icon_info", - type=dict, - location="json", - nullable=True, - ) - ) - args = parser.parse_args() - pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args) + payload = Payload.model_validate(console_ns.payload or {}) + pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump()) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) return 200 @@ -119,36 +95,14 @@ class CustomizedPipelineTemplateApi(Resource): @console_ns.route("/rag/pipelines//customized/publish") class PublishCustomizedPipelineTemplateApi(Resource): + @console_ns.expect(console_ns.models[Payload.__name__]) @setup_required @login_required @account_initialization_required @enterprise_license_required @knowledge_pipeline_publish_enabled def post(self, pipeline_id: str): - parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 40 characters.", - type=_validate_name, - ) - .add_argument( - "description", - type=_validate_description_length, - nullable=True, - required=False, - default="", - ) - .add_argument( - "icon_info", - type=dict, - location="json", - nullable=True, - ) - ) - args = parser.parse_args() + payload = Payload.model_validate(console_ns.payload or {}) rag_pipeline_service = RagPipelineService() - rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) + rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump()) return {"result": "success"} diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index 98876e9f5e..e65cb19b39 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -1,8 +1,10 @@ -from flask_restx import Resource, marshal, reqparse +from flask_restx import Resource, marshal +from pydantic import BaseModel from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden import services +from controllers.common.schema import register_schema_model from controllers.console import console_ns from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.wraps import ( @@ -19,22 +21,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService +class RagPipelineDatasetImportPayload(BaseModel): + yaml_content: str + + +register_schema_model(console_ns, RagPipelineDatasetImportPayload) + + @console_ns.route("/rag/pipeline/dataset") class CreateRagPipelineDatasetApi(Resource): + @console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__]) @setup_required @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") def post(self): - parser = reqparse.RequestParser().add_argument( - "yaml_content", - type=str, - nullable=False, - required=True, - help="yaml_content is required.", - ) - - args = parser.parse_args() + payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {}) current_user, current_tenant_id = current_account_with_tenant() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: @@ -49,7 +51,7 @@ class CreateRagPipelineDatasetApi(Resource): ), permission=DatasetPermissionEnum.ONLY_ME, partial_member_list=None, - yaml_content=args["yaml_content"], + yaml_content=payload.yaml_content, ) try: with Session(db.engine) as session: diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 858ba94bf8..720e2ce365 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -1,11 +1,13 @@ import logging -from typing import NoReturn +from typing import Any, NoReturn -from flask import Response -from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask import Response, request +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( DraftWorkflowNotExist, @@ -33,19 +35,21 @@ logger = logging.getLogger(__name__) def _create_pagination_parser(): - parser = ( - reqparse.RequestParser() - .add_argument( - "page", - type=inputs.int_range(1, 100_000), - required=False, - default=1, - location="args", - help="the page of data requested", - ) - .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") - ) - return parser + class PaginationQuery(BaseModel): + page: int = Field(default=1, ge=1, le=100_000) + limit: int = Field(default=20, ge=1, le=100) + + register_schema_models(console_ns, PaginationQuery) + + return PaginationQuery + + +class WorkflowDraftVariablePatchPayload(BaseModel): + name: str | None = None + value: Any | None = None + + +register_schema_models(console_ns, WorkflowDraftVariablePatchPayload) def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: @@ -93,8 +97,8 @@ class RagPipelineVariableCollectionApi(Resource): """ Get draft workflow """ - parser = _create_pagination_parser() - args = parser.parse_args() + pagination = _create_pagination_parser() + query = pagination.model_validate(request.args.to_dict()) # fetch draft workflow by app_model rag_pipeline_service = RagPipelineService() @@ -109,8 +113,8 @@ class RagPipelineVariableCollectionApi(Resource): ) workflow_vars = draft_var_srv.list_variables_without_values( app_id=pipeline.id, - page=args.page, - limit=args.limit, + page=query.page, + limit=query.limit, ) return workflow_vars @@ -186,6 +190,7 @@ class RagPipelineVariableApi(Resource): @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + @console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__]) def patch(self, pipeline: Pipeline, variable_id: str): # Request payload for file types: # @@ -208,16 +213,11 @@ class RagPipelineVariableApi(Resource): # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # } - parser = ( - reqparse.RequestParser() - .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") - .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") - ) - draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) - args = parser.parse_args(strict=True) + payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) variable = draft_var_srv.get_variable(variable_id=variable_id) if variable is None: diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index d658d65b71..d43ee9a6e0 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -1,6 +1,9 @@ -from flask_restx import Resource, marshal_with, reqparse # type: ignore +from flask import request +from flask_restx import Resource, marshal_with # type: ignore +from pydantic import BaseModel, Field from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( @@ -16,6 +19,25 @@ from services.app_dsl_service import ImportStatus from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService +class RagPipelineImportPayload(BaseModel): + mode: str + yaml_content: str | None = None + yaml_url: str | None = None + name: str | None = None + description: str | None = None + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + pipeline_id: str | None = None + + +class IncludeSecretQuery(BaseModel): + include_secret: str = Field(default="false") + + +register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery) + + @console_ns.route("/rag/pipelines/imports") class RagPipelineImportApi(Resource): @setup_required @@ -23,23 +45,11 @@ class RagPipelineImportApi(Resource): @account_initialization_required @edit_permission_required @marshal_with(pipeline_import_fields) + @console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__]) def post(self): # Check user role first current_user, _ = current_account_with_tenant() - - parser = ( - reqparse.RequestParser() - .add_argument("mode", type=str, required=True, location="json") - .add_argument("yaml_content", type=str, location="json") - .add_argument("yaml_url", type=str, location="json") - .add_argument("name", type=str, location="json") - .add_argument("description", type=str, location="json") - .add_argument("icon_type", type=str, location="json") - .add_argument("icon", type=str, location="json") - .add_argument("icon_background", type=str, location="json") - .add_argument("pipeline_id", type=str, location="json") - ) - args = parser.parse_args() + payload = RagPipelineImportPayload.model_validate(console_ns.payload or {}) # Create service with session with Session(db.engine) as session: @@ -48,11 +58,11 @@ class RagPipelineImportApi(Resource): account = current_user result = import_service.import_rag_pipeline( account=account, - import_mode=args["mode"], - yaml_content=args.get("yaml_content"), - yaml_url=args.get("yaml_url"), - pipeline_id=args.get("pipeline_id"), - dataset_name=args.get("name"), + import_mode=payload.mode, + yaml_content=payload.yaml_content, + yaml_url=payload.yaml_url, + pipeline_id=payload.pipeline_id, + dataset_name=payload.name, ) session.commit() @@ -114,13 +124,12 @@ class RagPipelineExportApi(Resource): @edit_permission_required def get(self, pipeline: Pipeline): # Add include_secret params - parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args") - args = parser.parse_args() + query = IncludeSecretQuery.model_validate(request.args.to_dict()) with Session(db.engine) as session: export_service = RagPipelineDslService(session) result = export_service.export_rag_pipeline_dsl( - pipeline=pipeline, include_secret=args["include_secret"] == "true" + pipeline=pipeline, include_secret=query.include_secret == "true" ) return {"data": result}, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index a0dc692c4e..debe8eed97 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -1,14 +1,16 @@ import json import logging -from typing import cast +from typing import Any, Literal, cast +from uuid import UUID from flask import abort, request -from flask_restx import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore -from flask_restx.inputs import int_range # type: ignore +from flask_restx import Resource, marshal_with # type: ignore +from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( ConversationCompletedError, @@ -36,7 +38,7 @@ from fields.workflow_run_fields import ( workflow_run_pagination_fields, ) from libs import helper -from libs.helper import TimestampField, uuid_value +from libs.helper import TimestampField from libs.login import current_account_with_tenant, current_user, login_required from models import Account from models.dataset import Pipeline @@ -51,6 +53,91 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran logger = logging.getLogger(__name__) +class DraftWorkflowSyncPayload(BaseModel): + graph: dict[str, Any] + hash: str | None = None + environment_variables: list[dict[str, Any]] | None = None + conversation_variables: list[dict[str, Any]] | None = None + rag_pipeline_variables: list[dict[str, Any]] | None = None + features: dict[str, Any] | None = None + + +class NodeRunPayload(BaseModel): + inputs: dict[str, Any] | None = None + + +class NodeRunRequiredPayload(BaseModel): + inputs: dict[str, Any] + + +class DatasourceNodeRunPayload(BaseModel): + inputs: dict[str, Any] + datasource_type: str + credential_id: str | None = None + + +class DraftWorkflowRunPayload(BaseModel): + inputs: dict[str, Any] + datasource_type: str + datasource_info_list: list[dict[str, Any]] + start_node_id: str + + +class PublishedWorkflowRunPayload(DraftWorkflowRunPayload): + is_preview: bool = False + response_mode: Literal["streaming", "blocking"] = "streaming" + original_document_id: str | None = None + + +class DefaultBlockConfigQuery(BaseModel): + q: str | None = None + + +class WorkflowListQuery(BaseModel): + page: int = Field(default=1, ge=1, le=99999) + limit: int = Field(default=10, ge=1, le=100) + user_id: str | None = None + named_only: bool = False + + +class WorkflowUpdatePayload(BaseModel): + marked_name: str | None = Field(default=None, max_length=20) + marked_comment: str | None = Field(default=None, max_length=100) + + +class NodeIdQuery(BaseModel): + node_id: str + + +class WorkflowRunQuery(BaseModel): + last_id: UUID | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class DatasourceVariablesPayload(BaseModel): + datasource_type: str + datasource_info: dict[str, Any] + start_node_id: str + start_node_title: str + + +register_schema_models( + console_ns, + DraftWorkflowSyncPayload, + NodeRunPayload, + NodeRunRequiredPayload, + DatasourceNodeRunPayload, + DraftWorkflowRunPayload, + PublishedWorkflowRunPayload, + DefaultBlockConfigQuery, + WorkflowListQuery, + WorkflowUpdatePayload, + NodeIdQuery, + WorkflowRunQuery, + DatasourceVariablesPayload, +) + + @console_ns.route("/rag/pipelines//workflows/draft") class DraftRagPipelineApi(Resource): @setup_required @@ -88,15 +175,7 @@ class DraftRagPipelineApi(Resource): content_type = request.headers.get("Content-Type", "") if "application/json" in content_type: - parser = ( - reqparse.RequestParser() - .add_argument("graph", type=dict, required=True, nullable=False, location="json") - .add_argument("hash", type=str, required=False, location="json") - .add_argument("environment_variables", type=list, required=False, location="json") - .add_argument("conversation_variables", type=list, required=False, location="json") - .add_argument("rag_pipeline_variables", type=list, required=False, location="json") - ) - args = parser.parse_args() + payload_dict = console_ns.payload or {} elif "text/plain" in content_type: try: data = json.loads(request.data.decode("utf-8")) @@ -106,7 +185,7 @@ class DraftRagPipelineApi(Resource): if not isinstance(data.get("graph"), dict): raise ValueError("graph is not a dict") - args = { + payload_dict = { "graph": data.get("graph"), "features": data.get("features"), "hash": data.get("hash"), @@ -119,24 +198,26 @@ class DraftRagPipelineApi(Resource): else: abort(415) + payload = DraftWorkflowSyncPayload.model_validate(payload_dict) + try: - environment_variables_list = args.get("environment_variables") or [] + environment_variables_list = payload.environment_variables or [] environment_variables = [ variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] - conversation_variables_list = args.get("conversation_variables") or [] + conversation_variables_list = payload.conversation_variables or [] conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, - graph=args["graph"], - unique_hash=args.get("hash"), + graph=payload.graph, + unique_hash=payload.hash, account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, - rag_pipeline_variables=args.get("rag_pipeline_variables") or [], + rag_pipeline_variables=payload.rag_pipeline_variables or [], ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() @@ -148,12 +229,9 @@ class DraftRagPipelineApi(Resource): } -parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json") - - @console_ns.route("/rag/pipelines//workflows/draft/iteration/nodes//run") class RagPipelineDraftRunIterationNodeApi(Resource): - @console_ns.expect(parser_run) + @console_ns.expect(console_ns.models[NodeRunPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -166,7 +244,8 @@ class RagPipelineDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_run.parse_args() + payload = NodeRunPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) try: response = PipelineGenerateService.generate_single_iteration( @@ -187,7 +266,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/loop/nodes//run") class RagPipelineDraftRunLoopNodeApi(Resource): - @console_ns.expect(parser_run) + @console_ns.expect(console_ns.models[NodeRunPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -200,7 +279,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_run.parse_args() + payload = NodeRunPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) try: response = PipelineGenerateService.generate_single_loop( @@ -219,18 +299,9 @@ class RagPipelineDraftRunLoopNodeApi(Resource): raise InternalServerError() -parser_draft_run = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("datasource_info_list", type=list, required=True, location="json") - .add_argument("start_node_id", type=str, required=True, location="json") -) - - @console_ns.route("/rag/pipelines//workflows/draft/run") class DraftRagPipelineRunApi(Resource): - @console_ns.expect(parser_draft_run) + @console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -243,7 +314,8 @@ class DraftRagPipelineRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_draft_run.parse_args() + payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump() try: response = PipelineGenerateService.generate( @@ -259,21 +331,9 @@ class DraftRagPipelineRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) -parser_published_run = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("datasource_info_list", type=list, required=True, location="json") - .add_argument("start_node_id", type=str, required=True, location="json") - .add_argument("is_preview", type=bool, required=True, location="json", default=False) - .add_argument("response_mode", type=str, required=True, location="json", default="streaming") - .add_argument("original_document_id", type=str, required=False, location="json") -) - - @console_ns.route("/rag/pipelines//workflows/published/run") class PublishedRagPipelineRunApi(Resource): - @console_ns.expect(parser_published_run) + @console_ns.expect(console_ns.models[PublishedWorkflowRunPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -286,16 +346,16 @@ class PublishedRagPipelineRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_published_run.parse_args() - - streaming = args["response_mode"] == "streaming" + payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) + streaming = payload.response_mode == "streaming" try: response = PipelineGenerateService.generate( pipeline=pipeline, user=current_user, args=args, - invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED, + invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED, streaming=streaming, ) @@ -387,17 +447,9 @@ class PublishedRagPipelineRunApi(Resource): # # return result # -parser_rag_run = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("credential_id", type=str, required=False, location="json") -) - - @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//run") class RagPipelinePublishedDatasourceNodeRunApi(Resource): - @console_ns.expect(parser_rag_run) + @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -410,14 +462,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_rag_run.parse_args() - - inputs = args.get("inputs") - if inputs is None: - raise ValueError("missing inputs") - datasource_type = args.get("datasource_type") - if datasource_type is None: - raise ValueError("missing datasource_type") + payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {}) rag_pipeline_service = RagPipelineService() return helper.compact_generate_response( @@ -425,11 +470,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): rag_pipeline_service.run_datasource_workflow_node( pipeline=pipeline, node_id=node_id, - user_inputs=inputs, + user_inputs=payload.inputs, account=current_user, - datasource_type=datasource_type, + datasource_type=payload.datasource_type, is_published=False, - credential_id=args.get("credential_id"), + credential_id=payload.credential_id, ) ) ) @@ -437,7 +482,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/datasource/nodes//run") class RagPipelineDraftDatasourceNodeRunApi(Resource): - @console_ns.expect(parser_rag_run) + @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__]) @setup_required @login_required @edit_permission_required @@ -450,14 +495,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_rag_run.parse_args() - - inputs = args.get("inputs") - if inputs is None: - raise ValueError("missing inputs") - datasource_type = args.get("datasource_type") - if datasource_type is None: - raise ValueError("missing datasource_type") + payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {}) rag_pipeline_service = RagPipelineService() return helper.compact_generate_response( @@ -465,24 +503,19 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): rag_pipeline_service.run_datasource_workflow_node( pipeline=pipeline, node_id=node_id, - user_inputs=inputs, + user_inputs=payload.inputs, account=current_user, - datasource_type=datasource_type, + datasource_type=payload.datasource_type, is_published=False, - credential_id=args.get("credential_id"), + credential_id=payload.credential_id, ) ) ) -parser_run_api = reqparse.RequestParser().add_argument( - "inputs", type=dict, required=True, nullable=False, location="json" -) - - @console_ns.route("/rag/pipelines//workflows/draft/nodes//run") class RagPipelineDraftNodeRunApi(Resource): - @console_ns.expect(parser_run_api) + @console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__]) @setup_required @login_required @edit_permission_required @@ -496,11 +529,8 @@ class RagPipelineDraftNodeRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - args = parser_run_api.parse_args() - - inputs = args.get("inputs") - if inputs == None: - raise ValueError("missing inputs") + payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {}) + inputs = payload.inputs rag_pipeline_service = RagPipelineService() workflow_node_execution = rag_pipeline_service.run_draft_workflow_node( @@ -602,12 +632,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource): return rag_pipeline_service.get_default_block_configs() -parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args") - - @console_ns.route("/rag/pipelines//workflows/default-workflow-block-configs/") class DefaultRagPipelineBlockConfigApi(Resource): - @console_ns.expect(parser_default) @setup_required @login_required @account_initialization_required @@ -617,14 +643,12 @@ class DefaultRagPipelineBlockConfigApi(Resource): """ Get default block config """ - args = parser_default.parse_args() - - q = args.get("q") + query = DefaultBlockConfigQuery.model_validate(request.args.to_dict()) filters = None - if q: + if query.q: try: - filters = json.loads(args.get("q", "")) + filters = json.loads(query.q) except json.JSONDecodeError: raise ValueError("Invalid filters") @@ -633,18 +657,8 @@ class DefaultRagPipelineBlockConfigApi(Resource): return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters) -parser_wf = ( - reqparse.RequestParser() - .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") - .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args") - .add_argument("user_id", type=str, required=False, location="args") - .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args") -) - - @console_ns.route("/rag/pipelines//workflows") class PublishedAllRagPipelineApi(Resource): - @console_ns.expect(parser_wf) @setup_required @login_required @account_initialization_required @@ -657,16 +671,16 @@ class PublishedAllRagPipelineApi(Resource): """ current_user, _ = current_account_with_tenant() - args = parser_wf.parse_args() - page = args["page"] - limit = args["limit"] - user_id = args.get("user_id") - named_only = args.get("named_only", False) + query = WorkflowListQuery.model_validate(request.args.to_dict()) + + page = query.page + limit = query.limit + user_id = query.user_id + named_only = query.named_only if user_id: if user_id != current_user.id: raise Forbidden() - user_id = cast(str, user_id) rag_pipeline_service = RagPipelineService() with Session(db.engine) as session: @@ -687,16 +701,8 @@ class PublishedAllRagPipelineApi(Resource): } -parser_wf_id = ( - reqparse.RequestParser() - .add_argument("marked_name", type=str, required=False, location="json") - .add_argument("marked_comment", type=str, required=False, location="json") -) - - @console_ns.route("/rag/pipelines//workflows/") class RagPipelineByIdApi(Resource): - @console_ns.expect(parser_wf_id) @setup_required @login_required @account_initialization_required @@ -710,20 +716,8 @@ class RagPipelineByIdApi(Resource): # Check permission current_user, _ = current_account_with_tenant() - args = parser_wf_id.parse_args() - - # Validate name and comment length - if args.marked_name and len(args.marked_name) > 20: - raise ValueError("Marked name cannot exceed 20 characters") - if args.marked_comment and len(args.marked_comment) > 100: - raise ValueError("Marked comment cannot exceed 100 characters") - - # Prepare update data - update_data = {} - if args.get("marked_name") is not None: - update_data["marked_name"] = args["marked_name"] - if args.get("marked_comment") is not None: - update_data["marked_comment"] = args["marked_comment"] + payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {}) + update_data = payload.model_dump(exclude_unset=True) if not update_data: return {"message": "No valid fields to update"}, 400 @@ -749,12 +743,8 @@ class RagPipelineByIdApi(Resource): return workflow -parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args") - - @console_ns.route("/rag/pipelines//workflows/published/processing/parameters") class PublishedRagPipelineSecondStepApi(Resource): - @console_ns.expect(parser_parameters) @setup_required @login_required @account_initialization_required @@ -764,10 +754,8 @@ class PublishedRagPipelineSecondStepApi(Resource): """ Get second step parameters of rag pipeline """ - args = parser_parameters.parse_args() - node_id = args.get("node_id") - if not node_id: - raise ValueError("Node ID is required") + query = NodeIdQuery.model_validate(request.args.to_dict()) + node_id = query.node_id rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) return { @@ -777,7 +765,6 @@ class PublishedRagPipelineSecondStepApi(Resource): @console_ns.route("/rag/pipelines//workflows/published/pre-processing/parameters") class PublishedRagPipelineFirstStepApi(Resource): - @console_ns.expect(parser_parameters) @setup_required @login_required @account_initialization_required @@ -787,10 +774,8 @@ class PublishedRagPipelineFirstStepApi(Resource): """ Get first step parameters of rag pipeline """ - args = parser_parameters.parse_args() - node_id = args.get("node_id") - if not node_id: - raise ValueError("Node ID is required") + query = NodeIdQuery.model_validate(request.args.to_dict()) + node_id = query.node_id rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False) return { @@ -800,7 +785,6 @@ class PublishedRagPipelineFirstStepApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/pre-processing/parameters") class DraftRagPipelineFirstStepApi(Resource): - @console_ns.expect(parser_parameters) @setup_required @login_required @account_initialization_required @@ -810,10 +794,8 @@ class DraftRagPipelineFirstStepApi(Resource): """ Get first step parameters of rag pipeline """ - args = parser_parameters.parse_args() - node_id = args.get("node_id") - if not node_id: - raise ValueError("Node ID is required") + query = NodeIdQuery.model_validate(request.args.to_dict()) + node_id = query.node_id rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) return { @@ -823,7 +805,6 @@ class DraftRagPipelineFirstStepApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/processing/parameters") class DraftRagPipelineSecondStepApi(Resource): - @console_ns.expect(parser_parameters) @setup_required @login_required @account_initialization_required @@ -833,10 +814,8 @@ class DraftRagPipelineSecondStepApi(Resource): """ Get second step parameters of rag pipeline """ - args = parser_parameters.parse_args() - node_id = args.get("node_id") - if not node_id: - raise ValueError("Node ID is required") + query = NodeIdQuery.model_validate(request.args.to_dict()) + node_id = query.node_id rag_pipeline_service = RagPipelineService() variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True) @@ -845,16 +824,8 @@ class DraftRagPipelineSecondStepApi(Resource): } -parser_wf_run = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") -) - - @console_ns.route("/rag/pipelines//workflow-runs") class RagPipelineWorkflowRunListApi(Resource): - @console_ns.expect(parser_wf_run) @setup_required @login_required @account_initialization_required @@ -864,7 +835,16 @@ class RagPipelineWorkflowRunListApi(Resource): """ Get workflow run list """ - args = parser_wf_run.parse_args() + query = WorkflowRunQuery.model_validate( + { + "last_id": request.args.get("last_id"), + "limit": request.args.get("limit", type=int, default=20), + } + ) + args = { + "last_id": str(query.last_id) if query.last_id else None, + "limit": query.limit, + } rag_pipeline_service = RagPipelineService() result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args) @@ -964,18 +944,9 @@ class RagPipelineTransformApi(Resource): return result -parser_var = ( - reqparse.RequestParser() - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("datasource_info", type=dict, required=True, location="json") - .add_argument("start_node_id", type=str, required=True, location="json") - .add_argument("start_node_title", type=str, required=True, location="json") -) - - @console_ns.route("/rag/pipelines//workflows/draft/datasource/variables-inspect") class RagPipelineDatasourceVariableApi(Resource): - @console_ns.expect(parser_var) + @console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -987,7 +958,7 @@ class RagPipelineDatasourceVariableApi(Resource): Set datasource variables """ current_user, _ = current_account_with_tenant() - args = parser_var.parse_args() + args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump() rag_pipeline_service = RagPipelineService() workflow_node_execution = rag_pipeline_service.set_datasource_variables( diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index b2998a8d3e..335c8f6030 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -1,5 +1,10 @@ -from flask_restx import Resource, fields, reqparse +from typing import Literal +from flask import request +from flask_restx import Resource +from pydantic import BaseModel + +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.wraps import account_initialization_required, setup_required @@ -7,48 +12,35 @@ from libs.login import login_required from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService +class WebsiteCrawlPayload(BaseModel): + provider: Literal["firecrawl", "watercrawl", "jinareader"] + url: str + options: dict[str, object] + + +class WebsiteCrawlStatusQuery(BaseModel): + provider: Literal["firecrawl", "watercrawl", "jinareader"] + + +register_schema_models(console_ns, WebsiteCrawlPayload, WebsiteCrawlStatusQuery) + + @console_ns.route("/website/crawl") class WebsiteCrawlApi(Resource): @console_ns.doc("crawl_website") @console_ns.doc(description="Crawl website content") - @console_ns.expect( - console_ns.model( - "WebsiteCrawlRequest", - { - "provider": fields.String( - required=True, - description="Crawl provider (firecrawl/watercrawl/jinareader)", - enum=["firecrawl", "watercrawl", "jinareader"], - ), - "url": fields.String(required=True, description="URL to crawl"), - "options": fields.Raw(required=True, description="Crawl options"), - }, - ) - ) + @console_ns.expect(console_ns.models[WebsiteCrawlPayload.__name__]) @console_ns.response(200, "Website crawl initiated successfully") @console_ns.response(400, "Invalid crawl parameters") @setup_required @login_required @account_initialization_required def post(self): - parser = ( - reqparse.RequestParser() - .add_argument( - "provider", - type=str, - choices=["firecrawl", "watercrawl", "jinareader"], - required=True, - nullable=True, - location="json", - ) - .add_argument("url", type=str, required=True, nullable=True, location="json") - .add_argument("options", type=dict, required=True, nullable=True, location="json") - ) - args = parser.parse_args() + payload = WebsiteCrawlPayload.model_validate(console_ns.payload or {}) # Create typed request and validate try: - api_request = WebsiteCrawlApiRequest.from_args(args) + api_request = WebsiteCrawlApiRequest.from_args(payload.model_dump()) except ValueError as e: raise WebsiteCrawlError(str(e)) @@ -65,6 +57,7 @@ class WebsiteCrawlStatusApi(Resource): @console_ns.doc("get_crawl_status") @console_ns.doc(description="Get website crawl status") @console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"}) + @console_ns.expect(console_ns.models[WebsiteCrawlStatusQuery.__name__]) @console_ns.response(200, "Crawl status retrieved successfully") @console_ns.response(404, "Crawl job not found") @console_ns.response(400, "Invalid provider") @@ -72,14 +65,11 @@ class WebsiteCrawlStatusApi(Resource): @login_required @account_initialization_required def get(self, job_id: str): - parser = reqparse.RequestParser().add_argument( - "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args" - ) - args = parser.parse_args() + args = WebsiteCrawlStatusQuery.model_validate(request.args.to_dict()) # Create typed request and validate try: - api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id) + api_request = WebsiteCrawlStatusApiRequest.from_args(args.model_dump(), job_id) except ValueError as e: raise WebsiteCrawlError(str(e)) diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 2a248cf20d..0311db1584 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,9 +1,11 @@ import logging from flask import request +from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.schema import register_schema_model from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -31,6 +33,16 @@ from .. import console_ns logger = logging.getLogger(__name__) +class TextToAudioPayload(BaseModel): + message_id: str | None = None + voice: str | None = None + text: str | None = None + streaming: bool | None = Field(default=None, description="Enable streaming response") + + +register_schema_model(console_ns, TextToAudioPayload) + + @console_ns.route( "/installed-apps//audio-to-text", endpoint="installed_app_audio", @@ -76,23 +88,15 @@ class ChatAudioApi(InstalledAppResource): endpoint="installed_app_text", ) class ChatTextApi(InstalledAppResource): + @console_ns.expect(console_ns.models[TextToAudioPayload.__name__]) def post(self, installed_app): - from flask_restx import reqparse - app_model = installed_app.app try: - parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=str, required=False, location="json") - .add_argument("voice", type=str, location="json") - .add_argument("text", type=str, location="json") - .add_argument("streaming", type=bool, location="json") - ) - args = parser.parse_args() + payload = TextToAudioPayload.model_validate(console_ns.payload or {}) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = payload.message_id + text = payload.text + voice = payload.voice response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) return response diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 52d6426e7f..5901eca915 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,9 +1,12 @@ import logging +from typing import Any, Literal +from uuid import UUID -from flask_restx import reqparse +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -25,7 +28,6 @@ from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs import helper from libs.datetime_utils import naive_utc_now -from libs.helper import uuid_value from libs.login import current_user from models import Account from models.model import AppMode @@ -38,28 +40,56 @@ from .. import console_ns logger = logging.getLogger(__name__) +class CompletionMessagePayload(BaseModel): + inputs: dict[str, Any] + query: str = "" + files: list[dict[str, Any]] | None = None + response_mode: Literal["blocking", "streaming"] | None = None + retriever_from: str = Field(default="explore_app") + + +class ChatMessagePayload(BaseModel): + inputs: dict[str, Any] + query: str + files: list[dict[str, Any]] | None = None + conversation_id: str | None = None + parent_message_id: str | None = None + retriever_from: str = Field(default="explore_app") + + @field_validator("conversation_id", "parent_message_id", mode="before") + @classmethod + def normalize_uuid(cls, value: str | UUID | None) -> str | None: + """ + Accept blank IDs and validate UUID format when provided. + """ + if not value: + return None + + try: + return helper.uuid_value(value) + except ValueError as exc: + raise ValueError("must be a valid UUID") from exc + + +register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) + + # define completion api for user @console_ns.route( "/installed-apps//completion-messages", endpoint="installed_app_completion", ) class CompletionApi(InstalledAppResource): + @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__]) def post(self, installed_app): app_model = installed_app.app if app_model.mode != AppMode.COMPLETION: raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, location="json", default="") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - .add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") - ) - args = parser.parse_args() + payload = CompletionMessagePayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False installed_app.last_used_at = naive_utc_now() @@ -123,22 +153,15 @@ class CompletionStopApi(InstalledAppResource): endpoint="installed_app_chat_completion", ) class ChatApi(InstalledAppResource): + @console_ns.expect(console_ns.models[ChatMessagePayload.__name__]) def post(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json") - .add_argument("query", type=str, required=True, location="json") - .add_argument("files", type=list, required=False, location="json") - .add_argument("conversation_id", type=uuid_value, location="json") - .add_argument("parent_message_id", type=uuid_value, required=False, location="json") - .add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") - ) - args = parser.parse_args() + payload = ChatMessagePayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) args["auto_generate_name"] = False diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 5a39363cc2..92da591ab4 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,14 +1,18 @@ -from flask_restx import marshal_with, reqparse -from flask_restx.inputs import int_range +from typing import Any +from uuid import UUID + +from flask import request +from flask_restx import marshal_with +from pydantic import BaseModel, Field, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields -from libs.helper import uuid_value from libs.login import current_user from models import Account from models.model import AppMode @@ -19,29 +23,51 @@ from services.web_conversation_service import WebConversationService from .. import console_ns +class ConversationListQuery(BaseModel): + last_id: UUID | None = None + limit: int = Field(default=20, ge=1, le=100) + pinned: bool | None = None + + +class ConversationRenamePayload(BaseModel): + name: str | None = None + auto_generate: bool = False + + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + + +register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload) + + @console_ns.route( "/installed-apps//conversations", endpoint="installed_app_conversations", ) class ConversationListApi(InstalledAppResource): @marshal_with(conversation_infinite_scroll_pagination_fields) + @console_ns.expect(console_ns.models[ConversationListQuery.__name__]) def get(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - .add_argument("pinned", type=str, choices=["true", "false", None], location="args") - ) - args = parser.parse_args() - - pinned = None - if "pinned" in args and args["pinned"] is not None: - pinned = args["pinned"] == "true" + raw_args: dict[str, Any] = { + "last_id": request.args.get("last_id"), + "limit": request.args.get("limit", default=20, type=int), + "pinned": request.args.get("pinned"), + } + if raw_args["last_id"] is None: + raw_args["last_id"] = None + pinned_value = raw_args["pinned"] + if isinstance(pinned_value, str): + raw_args["pinned"] = pinned_value == "true" + args = ConversationListQuery.model_validate(raw_args) try: if not isinstance(current_user, Account): @@ -51,10 +77,10 @@ class ConversationListApi(InstalledAppResource): session=session, app_model=app_model, user=current_user, - last_id=args["last_id"], - limit=args["limit"], + last_id=str(args.last_id) if args.last_id else None, + limit=args.limit, invoke_from=InvokeFrom.EXPLORE, - pinned=pinned, + pinned=args.pinned, ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -88,6 +114,7 @@ class ConversationApi(InstalledAppResource): ) class ConversationRenameApi(InstalledAppResource): @marshal_with(simple_conversation_fields) + @console_ns.expect(console_ns.models[ConversationRenamePayload.__name__]) def post(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) @@ -96,18 +123,13 @@ class ConversationRenameApi(InstalledAppResource): conversation_id = str(c_id) - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, location="json") - .add_argument("auto_generate", type=bool, required=False, default=False, location="json") - ) - args = parser.parse_args() + payload = ConversationRenamePayload.model_validate(console_ns.payload or {}) try: if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") return ConversationService.rename( - app_model, conversation_id, current_user, args["name"], args["auto_generate"] + app_model, conversation_id, current_user, payload.name, payload.auto_generate ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index db854e09bb..229b7c8865 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,9 +1,13 @@ import logging +from typing import Literal +from uuid import UUID -from flask_restx import marshal_with, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.console.app.error import ( AppMoreLikeThisDisabledError, CompletionRequestError, @@ -22,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs import helper -from libs.helper import uuid_value from libs.login import current_account_with_tenant from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -40,12 +43,31 @@ from .. import console_ns logger = logging.getLogger(__name__) +class MessageListQuery(BaseModel): + conversation_id: UUID + first_id: UUID | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class MessageFeedbackPayload(BaseModel): + rating: Literal["like", "dislike"] | None = None + content: str | None = None + + +class MoreLikeThisQuery(BaseModel): + response_mode: Literal["blocking", "streaming"] + + +register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, MoreLikeThisQuery) + + @console_ns.route( "/installed-apps//messages", endpoint="installed_app_messages", ) class MessageListApi(InstalledAppResource): @marshal_with(message_infinite_scroll_pagination_fields) + @console_ns.expect(console_ns.models[MessageListQuery.__name__]) def get(self, installed_app): current_user, _ = current_account_with_tenant() app_model = installed_app.app @@ -53,18 +75,15 @@ class MessageListApi(InstalledAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - - parser = ( - reqparse.RequestParser() - .add_argument("conversation_id", required=True, type=uuid_value, location="args") - .add_argument("first_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + args = MessageListQuery.model_validate(request.args.to_dict()) try: return MessageService.pagination_by_first_id( - app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] + app_model, + current_user, + str(args.conversation_id), + str(args.first_id) if args.first_id else None, + args.limit, ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -77,26 +96,22 @@ class MessageListApi(InstalledAppResource): endpoint="installed_app_message_feedback", ) class MessageFeedbackApi(InstalledAppResource): + @console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__]) def post(self, installed_app, message_id): current_user, _ = current_account_with_tenant() app_model = installed_app.app message_id = str(message_id) - parser = ( - reqparse.RequestParser() - .add_argument("rating", type=str, choices=["like", "dislike", None], location="json") - .add_argument("content", type=str, location="json") - ) - args = parser.parse_args() + payload = MessageFeedbackPayload.model_validate(console_ns.payload or {}) try: MessageService.create_feedback( app_model=app_model, message_id=message_id, user=current_user, - rating=args.get("rating"), - content=args.get("content"), + rating=payload.rating, + content=payload.content, ) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -109,6 +124,7 @@ class MessageFeedbackApi(InstalledAppResource): endpoint="installed_app_more_like_this", ) class MessageMoreLikeThisApi(InstalledAppResource): + @console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__]) def get(self, installed_app, message_id): current_user, _ = current_account_with_tenant() app_model = installed_app.app @@ -117,12 +133,9 @@ class MessageMoreLikeThisApi(InstalledAppResource): message_id = str(message_id) - parser = reqparse.RequestParser().add_argument( - "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" - ) - args = parser.parse_args() + args = MoreLikeThisQuery.model_validate(request.args.to_dict()) - streaming = args["response_mode"] == "streaming" + streaming = args.response_mode == "streaming" try: response = AppGenerateService.generate_more_like_this( diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 5a9c3ef133..2b2f807694 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,4 +1,6 @@ -from flask_restx import Resource, fields, marshal_with, reqparse +from flask import request +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field from constants.languages import languages from controllers.console import console_ns @@ -35,20 +37,26 @@ recommended_app_list_fields = { } -parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args") +class RecommendedAppsQuery(BaseModel): + language: str | None = Field(default=None) + + +console_ns.schema_model( + RecommendedAppsQuery.__name__, + RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"), +) @console_ns.route("/explore/apps") class RecommendedAppListApi(Resource): - @console_ns.expect(parser_apps) + @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__]) @login_required @account_initialization_required @marshal_with(recommended_app_list_fields) def get(self): # language args - args = parser_apps.parse_args() - - language = args.get("language") + args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + language = args.language if language and language in languages: language_prefix = language elif current_user and current_user.interface_language: diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 9775c951f7..6a9e274a0e 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,16 +1,33 @@ -from flask_restx import fields, marshal_with, reqparse -from flask_restx.inputs import int_range +from uuid import UUID + +from flask import request +from flask_restx import fields, marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields -from libs.helper import TimestampField, uuid_value +from libs.helper import TimestampField from libs.login import current_account_with_tenant from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService + +class SavedMessageListQuery(BaseModel): + last_id: UUID | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class SavedMessageCreatePayload(BaseModel): + message_id: UUID + + +register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload) + + feedback_fields = {"rating": fields.String} message_fields = { @@ -33,32 +50,33 @@ class SavedMessageListApi(InstalledAppResource): } @marshal_with(saved_message_infinite_scroll_pagination_fields) + @console_ns.expect(console_ns.models[SavedMessageListQuery.__name__]) def get(self, installed_app): current_user, _ = current_account_with_tenant() app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() - parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args") - .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + args = SavedMessageListQuery.model_validate(request.args.to_dict()) + + return SavedMessageService.pagination_by_last_id( + app_model, + current_user, + str(args.last_id) if args.last_id else None, + args.limit, ) - args = parser.parse_args() - - return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) + @console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__]) def post(self, installed_app): current_user, _ = current_account_with_tenant() app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() - parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json") - args = parser.parse_args() + payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {}) try: - SavedMessageService.save(app_model, current_user, args["message_id"]) + SavedMessageService.save(app_model, current_user, str(payload.message_id)) except MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 125f603a5a..d679d0722d 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,8 +1,10 @@ import logging +from typing import Any -from flask_restx import reqparse +from pydantic import BaseModel from werkzeug.exceptions import InternalServerError +from controllers.common.schema import register_schema_model from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -32,8 +34,17 @@ from .. import console_ns logger = logging.getLogger(__name__) +class WorkflowRunPayload(BaseModel): + inputs: dict[str, Any] + files: list[dict[str, Any]] | None = None + + +register_schema_model(console_ns, WorkflowRunPayload) + + @console_ns.route("/installed-apps//workflows/run") class InstalledAppWorkflowRunApi(InstalledAppResource): + @console_ns.expect(console_ns.models[WorkflowRunPayload.__name__]) def post(self, installed_app: InstalledApp): """ Run workflow @@ -46,12 +57,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("files", type=list, required=False, location="json") - ) - args = parser.parse_args() + payload = WorkflowRunPayload.model_validate(console_ns.payload or {}) + args = payload.model_dump(exclude_none=True) try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index fdd7c2f479..29417dc896 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -45,6 +45,9 @@ class FileApi(Resource): "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT, + "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, + "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, }, 200 @setup_required diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index f27fa26983..2bebe79eac 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,13 +1,13 @@ import os from flask import session -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource, fields +from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config from extensions.ext_database import db -from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService @@ -15,6 +15,18 @@ from . import console_ns from .error import AlreadySetupError, InitValidateFailedError from .wraps import only_edition_self_hosted +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class InitValidatePayload(BaseModel): + password: str = Field(..., max_length=30) + + +console_ns.schema_model( + InitValidatePayload.__name__, + InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + @console_ns.route("/init") class InitValidateAPI(Resource): @@ -37,12 +49,7 @@ class InitValidateAPI(Resource): @console_ns.doc("validate_init_password") @console_ns.doc(description="Validate initialization password for self-hosted edition") - @console_ns.expect( - console_ns.model( - "InitValidateRequest", - {"password": fields.String(required=True, description="Initialization password", max_length=30)}, - ) - ) + @console_ns.expect(console_ns.models[InitValidatePayload.__name__]) @console_ns.response( 201, "Success", @@ -57,8 +64,8 @@ class InitValidateAPI(Resource): if tenant_count > 0: raise AlreadySetupError() - parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json") - input_password = parser.parse_args()["password"] + payload = InitValidatePayload.model_validate(console_ns.payload) + input_password = payload.password if input_password != os.environ.get("INIT_PASSWORD"): session["is_init_validated"] = False diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 49a4df1b5a..47eef7eb7e 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -1,7 +1,8 @@ import urllib.parse import httpx -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field import services from controllers.common import helpers @@ -36,17 +37,23 @@ class RemoteFileInfoApi(Resource): } -parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required") +class RemoteFileUploadPayload(BaseModel): + url: str = Field(..., description="URL to fetch") + + +console_ns.schema_model( + RemoteFileUploadPayload.__name__, + RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"), +) @console_ns.route("/remote-files/upload") class RemoteFileUploadApi(Resource): - @console_ns.expect(parser_upload) + @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__]) @marshal_with(file_fields_with_signed_url) def post(self): - args = parser_upload.parse_args() - - url = args["url"] + args = RemoteFileUploadPayload.model_validate(console_ns.payload) + url = args.url try: resp = ssrf_proxy.head(url=url) diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 0c2a4d797b..7fa02ae280 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,8 +1,9 @@ from flask import request -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from configs import dify_config -from libs.helper import StrLen, email, extract_remote_ip +from libs.helper import EmailStr, extract_remote_ip from libs.password import valid_password from models.model import DifySetup, db from services.account_service import RegisterService, TenantService @@ -12,6 +13,26 @@ from .error import AlreadySetupError, NotInitValidateError from .init_validate import get_init_validate_status from .wraps import only_edition_self_hosted +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class SetupRequestPayload(BaseModel): + email: EmailStr = Field(..., description="Admin email address") + name: str = Field(..., max_length=30, description="Admin name (max 30 characters)") + password: str = Field(..., description="Admin password") + language: str | None = Field(default=None, description="Admin language") + + @field_validator("password") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) + + +console_ns.schema_model( + SetupRequestPayload.__name__, + SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + @console_ns.route("/setup") class SetupApi(Resource): @@ -42,17 +63,7 @@ class SetupApi(Resource): @console_ns.doc("setup_system") @console_ns.doc(description="Initialize system setup with admin account") - @console_ns.expect( - console_ns.model( - "SetupRequest", - { - "email": fields.String(required=True, description="Admin email address"), - "name": fields.String(required=True, description="Admin name (max 30 characters)"), - "password": fields.String(required=True, description="Admin password"), - "language": fields.String(required=False, description="Admin language"), - }, - ) - ) + @console_ns.expect(console_ns.models[SetupRequestPayload.__name__]) @console_ns.response( 201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")}) ) @@ -72,22 +83,15 @@ class SetupApi(Resource): if not get_init_validate_status(): raise NotInitValidateError() - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("name", type=StrLen(30), required=True, location="json") - .add_argument("password", type=valid_password, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = SetupRequestPayload.model_validate(console_ns.payload) # setup RegisterService.setup( - email=args["email"], - name=args["name"], - password=args["password"], + email=args.email, + name=args.name, + password=args.password, ip_address=extract_remote_ip(request), - language=args["language"], + language=args.language, ) return {"result": "success"}, 201 diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 4e3d9d6786..419261ba2a 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,8 +2,10 @@ import json import logging import httpx -from flask_restx import Resource, fields, reqparse +from flask import request +from flask_restx import Resource, fields from packaging import version +from pydantic import BaseModel, Field from configs import dify_config @@ -11,8 +13,14 @@ from . import console_ns logger = logging.getLogger(__name__) -parser = reqparse.RequestParser().add_argument( - "current_version", type=str, required=True, location="args", help="Current application version" + +class VersionQuery(BaseModel): + current_version: str = Field(..., description="Current application version") + + +console_ns.schema_model( + VersionQuery.__name__, + VersionQuery.model_json_schema(ref_template="#/definitions/{model}"), ) @@ -20,7 +28,7 @@ parser = reqparse.RequestParser().add_argument( class VersionApi(Resource): @console_ns.doc("check_version_update") @console_ns.doc(description="Check for application version updates") - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[VersionQuery.__name__]) @console_ns.response( 200, "Success", @@ -37,7 +45,7 @@ class VersionApi(Resource): ) def get(self): """Check for application version updates""" - args = parser.parse_args() + args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore check_update_url = dify_config.CHECK_UPDATE_URL result = { @@ -57,16 +65,16 @@ class VersionApi(Resource): try: response = httpx.get( check_update_url, - params={"current_version": args["current_version"]}, + params={"current_version": args.current_version}, timeout=httpx.Timeout(timeout=10.0, connect=3.0), ) except Exception as error: logger.warning("Check update version error: %s.", str(error)) - result["version"] = args["current_version"] + result["version"] = args.current_version return result content = json.loads(response.content) - if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"): + if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"): result["version"] = content["version"] result["release_date"] = content["releaseDate"] result["release_notes"] = content["releaseNotes"] diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 6334314988..55eaa2f09f 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -37,7 +37,7 @@ from controllers.console.wraps import ( from extensions.ext_database import db from fields.member_fields import account_fields from libs.datetime_utils import naive_utc_now -from libs.helper import TimestampField, email, extract_remote_ip, timezone +from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required from models import Account, AccountIntegrate, InvitationCode from services.account_service import AccountService @@ -111,14 +111,9 @@ class AccountDeletePayload(BaseModel): class AccountDeletionFeedbackPayload(BaseModel): - email: str + email: EmailStr feedback: str - @field_validator("email") - @classmethod - def validate_email(cls, value: str) -> str: - return email(value) - class EducationActivatePayload(BaseModel): token: str @@ -133,45 +128,25 @@ class EducationAutocompleteQuery(BaseModel): class ChangeEmailSendPayload(BaseModel): - email: str + email: EmailStr language: str | None = None phase: str | None = None token: str | None = None - @field_validator("email") - @classmethod - def validate_email(cls, value: str) -> str: - return email(value) - class ChangeEmailValidityPayload(BaseModel): - email: str + email: EmailStr code: str token: str - @field_validator("email") - @classmethod - def validate_email(cls, value: str) -> str: - return email(value) - class ChangeEmailResetPayload(BaseModel): - new_email: str + new_email: EmailStr token: str - @field_validator("new_email") - @classmethod - def validate_email(cls, value: str) -> str: - return email(value) - class CheckEmailUniquePayload(BaseModel): - email: str - - @field_validator("email") - @classmethod - def validate_email(cls, value: str) -> str: - return email(value) + email: EmailStr def reg(cls: type[BaseModel]): diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 246a869291..2def57ed7b 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -230,7 +230,7 @@ class ModelProviderModelApi(Resource): return {"result": "success"}, 200 - @console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True) + @console_ns.expect(console_ns.models[ParserDeleteModels.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource): tenant_id=tenant_id, provider_name=provider ) else: - model_type = args.model_type + # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM) + normalized_model_type = args.model_type.to_origin_model_type() available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( - tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model + tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model ) return jsonable_encoder( diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 69281c6214..268473d6d1 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -22,7 +22,12 @@ from services.trigger.trigger_subscription_builder_service import TriggerSubscri from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService from .. import console_ns -from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from ..wraps import ( + account_initialization_required, + edit_permission_required, + is_admin_or_owner_required, + setup_required, +) logger = logging.getLogger(__name__) @@ -72,7 +77,7 @@ class TriggerProviderInfoApi(Resource): class TriggerSubscriptionListApi(Resource): @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def get(self, provider): """List all trigger subscriptions for the current tenant's provider""" @@ -104,7 +109,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource): @console_ns.expect(parser) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider): """Add a new subscription instance for a trigger provider""" @@ -133,6 +138,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource): class TriggerSubscriptionBuilderGetApi(Resource): @setup_required @login_required + @edit_permission_required @account_initialization_required def get(self, provider, subscription_builder_id): """Get a subscription instance for a trigger provider""" @@ -155,7 +161,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): @console_ns.expect(parser_api) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Verify a subscription instance for a trigger provider""" @@ -200,6 +206,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): @console_ns.expect(parser_update_api) @setup_required @login_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Update a subscription instance for a trigger provider""" @@ -233,6 +240,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): class TriggerSubscriptionBuilderLogsApi(Resource): @setup_required @login_required + @edit_permission_required @account_initialization_required def get(self, provider, subscription_builder_id): """Get the request logs for a subscription instance for a trigger provider""" @@ -255,7 +263,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource): @console_ns.expect(parser_update_api) @setup_required @login_required - @is_admin_or_owner_required + @edit_permission_required @account_initialization_required def post(self, provider, subscription_builder_id): """Build a subscription instance for a trigger provider""" diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index d320855f29..64f47f426a 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,7 +1,8 @@ from urllib.parse import quote from flask import Response, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from werkzeug.exceptions import NotFound import services @@ -11,6 +12,26 @@ from extensions.ext_database import db from services.account_service import TenantService from services.file_service import FileService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class FileSignatureQuery(BaseModel): + timestamp: str = Field(..., description="Unix timestamp used in the signature") + nonce: str = Field(..., description="Random string for signature") + sign: str = Field(..., description="HMAC signature") + + +class FilePreviewQuery(FileSignatureQuery): + as_attachment: bool = Field(default=False, description="Whether to download as attachment") + + +files_ns.schema_model( + FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +files_ns.schema_model( + FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + @files_ns.route("//image-preview") class ImagePreviewApi(Resource): @@ -36,12 +57,10 @@ class ImagePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - timestamp = request.args.get("timestamp") - nonce = request.args.get("nonce") - sign = request.args.get("sign") - - if not timestamp or not nonce or not sign: - return {"content": "Invalid request."}, 400 + args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + timestamp = args.timestamp + nonce = args.nonce + sign = args.sign try: generator, mimetype = FileService(db.engine).get_image_preview( @@ -80,25 +99,14 @@ class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - parser = ( - reqparse.RequestParser() - .add_argument("timestamp", type=str, required=True, location="args") - .add_argument("nonce", type=str, required=True, location="args") - .add_argument("sign", type=str, required=True, location="args") - .add_argument("as_attachment", type=bool, required=False, default=False, location="args") - ) - - args = parser.parse_args() - - if not args["timestamp"] or not args["nonce"] or not args["sign"]: - return {"content": "Invalid request."}, 400 + args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore try: generator, upload_file = FileService(db.engine).get_file_generator_by_file_id( file_id=file_id, - timestamp=args["timestamp"], - nonce=args["nonce"], - sign=args["sign"], + timestamp=args.timestamp, + nonce=args.nonce, + sign=args.sign, ) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() @@ -125,7 +133,7 @@ class FilePreviewApi(Resource): response.headers["Accept-Ranges"] = "bytes" if upload_file.size > 0: response.headers["Content-Length"] = str(upload_file.size) - if args["as_attachment"]: + if args.as_attachment: encoded_filename = quote(upload_file.name) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" response.headers["Content-Type"] = "application/octet-stream" diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index ecaeb85821..c487a0a915 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,7 +1,8 @@ from urllib.parse import quote -from flask import Response -from flask_restx import Resource, reqparse +from flask import Response, request +from flask_restx import Resource +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, NotFound from controllers.common.errors import UnsupportedFileTypeError @@ -10,6 +11,20 @@ from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager from extensions.ext_database import db as global_db +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class ToolFileQuery(BaseModel): + timestamp: str = Field(..., description="Unix timestamp") + nonce: str = Field(..., description="Random nonce") + sign: str = Field(..., description="HMAC signature") + as_attachment: bool = Field(default=False, description="Download as attachment") + + +files_ns.schema_model( + ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + @files_ns.route("/tools/.") class ToolFileApi(Resource): @@ -36,18 +51,8 @@ class ToolFileApi(Resource): def get(self, file_id, extension): file_id = str(file_id) - parser = ( - reqparse.RequestParser() - .add_argument("timestamp", type=str, required=True, location="args") - .add_argument("nonce", type=str, required=True, location="args") - .add_argument("sign", type=str, required=True, location="args") - .add_argument("as_attachment", type=bool, required=False, default=False, location="args") - ) - - args = parser.parse_args() - if not verify_tool_file_signature( - file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"] - ): + args = ToolFileQuery.model_validate(request.args.to_dict()) + if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign): raise Forbidden("Invalid request.") try: @@ -69,7 +74,7 @@ class ToolFileApi(Resource): ) if tool_file.size > 0: response.headers["Content-Length"] = str(tool_file.size) - if args["as_attachment"]: + if args.as_attachment: encoded_filename = quote(tool_file.name) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index a09e24e2d9..6096a87c56 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -1,40 +1,45 @@ from mimetypes import guess_extension -from flask_restx import Resource, reqparse +from flask import request +from flask_restx import Resource from flask_restx.api import HTTPStatus +from pydantic import BaseModel, Field from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden import services -from controllers.common.errors import ( - FileTooLargeError, - UnsupportedFileTypeError, -) -from controllers.console.wraps import setup_required -from controllers.files import files_ns -from controllers.inner_api.plugin.wraps import get_user from core.file.helpers import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager from fields.file_fields import build_file_model -# Define parser for both documentation and validation -upload_parser = ( - reqparse.RequestParser() - .add_argument("file", location="files", type=FileStorage, required=True, help="File to upload") - .add_argument( - "timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification" - ) - .add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification") - .add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation") - .add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier") - .add_argument("user_id", type=str, required=False, location="args", help="User identifier") +from ..common.errors import ( + FileTooLargeError, + UnsupportedFileTypeError, +) +from ..console.wraps import setup_required +from ..files import files_ns +from ..inner_api.plugin.wraps import get_user + +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class PluginUploadQuery(BaseModel): + timestamp: str = Field(..., description="Unix timestamp for signature verification") + nonce: str = Field(..., description="Random nonce for signature verification") + sign: str = Field(..., description="HMAC signature") + tenant_id: str = Field(..., description="Tenant identifier") + user_id: str | None = Field(default=None, description="User identifier") + + +files_ns.schema_model( + PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) @files_ns.route("/upload/for-plugin") class PluginUploadFileApi(Resource): @setup_required - @files_ns.expect(upload_parser) + @files_ns.expect(files_ns.models[PluginUploadQuery.__name__]) @files_ns.doc("upload_plugin_file") @files_ns.doc(description="Upload a file for plugin usage with signature verification") @files_ns.doc( @@ -62,15 +67,17 @@ class PluginUploadFileApi(Resource): FileTooLargeError: File exceeds size limit UnsupportedFileTypeError: File type not supported """ - # Parse and validate all arguments - args = upload_parser.parse_args() + args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - file: FileStorage = args["file"] - timestamp: str = args["timestamp"] - nonce: str = args["nonce"] - sign: str = args["sign"] - tenant_id: str = args["tenant_id"] - user_id: str | None = args.get("user_id") + file: FileStorage | None = request.files.get("file") + if file is None: + raise Forbidden("File is required.") + + timestamp = args.timestamp + nonce = args.nonce + sign = args.sign + tenant_id = args.tenant_id + user_id = args.user_id user = get_user(tenant_id, user_id) filename: str | None = file.filename diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py index 7e40d81706..885ab7b78d 100644 --- a/api/controllers/inner_api/mail.py +++ b/api/controllers/inner_api/mail.py @@ -1,29 +1,38 @@ -from flask_restx import Resource, reqparse +from typing import Any +from flask_restx import Resource +from pydantic import BaseModel, Field + +from controllers.common.schema import register_schema_model from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only from tasks.mail_inner_task import send_inner_email_task -_mail_parser = ( - reqparse.RequestParser() - .add_argument("to", type=str, action="append", required=True) - .add_argument("subject", type=str, required=True) - .add_argument("body", type=str, required=True) - .add_argument("substitutions", type=dict, required=False) -) + +class InnerMailPayload(BaseModel): + to: list[str] = Field(description="Recipient email addresses", min_length=1) + subject: str + body: str + substitutions: dict[str, Any] | None = None + + +register_schema_model(inner_api_ns, InnerMailPayload) class BaseMail(Resource): """Shared logic for sending an inner email.""" + @inner_api_ns.doc("send_inner_mail") + @inner_api_ns.doc(description="Send internal email") + @inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__]) def post(self): - args = _mail_parser.parse_args() - send_inner_email_task.delay( # type: ignore - to=args["to"], - subject=args["subject"], - body=args["body"], - substitutions=args["substitutions"], + args = InnerMailPayload.model_validate(inner_api_ns.payload or {}) + send_inner_email_task.delay( + to=args.to, + subject=args.subject, + body=args.body, + substitutions=args.substitutions, # type: ignore ) return {"message": "success"}, 200 @@ -34,7 +43,7 @@ class EnterpriseMail(BaseMail): @inner_api_ns.doc("send_enterprise_mail") @inner_api_ns.doc(description="Send internal email for enterprise features") - @inner_api_ns.expect(_mail_parser) + @inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__]) @inner_api_ns.doc( responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} ) @@ -56,7 +65,7 @@ class BillingMail(BaseMail): @inner_api_ns.doc("send_billing_mail") @inner_api_ns.doc(description="Send internal email for billing notifications") - @inner_api_ns.expect(_mail_parser) + @inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__]) @inner_api_ns.doc( responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} ) diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 2a57bb745b..edf3ac393c 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -1,10 +1,9 @@ from collections.abc import Callable from functools import wraps -from typing import ParamSpec, TypeVar, cast +from typing import ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in -from flask_restx import reqparse from pydantic import BaseModel from sqlalchemy.orm import Session @@ -17,6 +16,11 @@ P = ParamSpec("P") R = TypeVar("R") +class TenantUserPayload(BaseModel): + tenant_id: str + user_id: str + + def get_user(tenant_id: str, user_id: str | None) -> EndUser: """ Get current user @@ -67,58 +71,45 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: return user_model -def get_user_tenant(view: Callable[P, R] | None = None): - def decorator(view_func: Callable[P, R]): - @wraps(view_func) - def decorated_view(*args: P.args, **kwargs: P.kwargs): - # fetch json body - parser = ( - reqparse.RequestParser() - .add_argument("tenant_id", type=str, required=True, location="json") - .add_argument("user_id", type=str, required=True, location="json") - ) +def get_user_tenant(view_func: Callable[P, R]): + @wraps(view_func) + def decorated_view(*args: P.args, **kwargs: P.kwargs): + payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {}) - p = parser.parse_args() + user_id = payload.user_id + tenant_id = payload.tenant_id - user_id = cast(str, p.get("user_id")) - tenant_id = cast(str, p.get("tenant_id")) + if not tenant_id: + raise ValueError("tenant_id is required") - if not tenant_id: - raise ValueError("tenant_id is required") + if not user_id: + user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID - if not user_id: - user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID - - try: - tenant_model = ( - db.session.query(Tenant) - .where( - Tenant.id == tenant_id, - ) - .first() + try: + tenant_model = ( + db.session.query(Tenant) + .where( + Tenant.id == tenant_id, ) - except Exception: - raise ValueError("tenant not found") + .first() + ) + except Exception: + raise ValueError("tenant not found") - if not tenant_model: - raise ValueError("tenant not found") + if not tenant_model: + raise ValueError("tenant not found") - kwargs["tenant_model"] = tenant_model + kwargs["tenant_model"] = tenant_model - user = get_user(tenant_id, user_id) - kwargs["user_model"] = user + user = get_user(tenant_id, user_id) + kwargs["user_model"] = user - current_app.login_manager._update_request_context_with_user(user) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore + current_app.login_manager._update_request_context_with_user(user) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore - return view_func(*args, **kwargs) + return view_func(*args, **kwargs) - return decorated_view - - if view is None: - return decorator - else: - return decorator(view) + return decorated_view def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]): diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 8391a15919..a5746abafa 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,7 +1,9 @@ import json -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel +from controllers.common.schema import register_schema_models from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns from controllers.inner_api.wraps import enterprise_inner_api_only @@ -11,12 +13,25 @@ from models import Account from services.account_service import TenantService +class WorkspaceCreatePayload(BaseModel): + name: str + owner_email: str + + +class WorkspaceOwnerlessPayload(BaseModel): + name: str + + +register_schema_models(inner_api_ns, WorkspaceCreatePayload, WorkspaceOwnerlessPayload) + + @inner_api_ns.route("/enterprise/workspace") class EnterpriseWorkspace(Resource): @setup_required @enterprise_inner_api_only @inner_api_ns.doc("create_enterprise_workspace") @inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment") + @inner_api_ns.expect(inner_api_ns.models[WorkspaceCreatePayload.__name__]) @inner_api_ns.doc( responses={ 200: "Workspace created successfully", @@ -25,18 +40,13 @@ class EnterpriseWorkspace(Resource): } ) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=True, location="json") - .add_argument("owner_email", type=str, required=True, location="json") - ) - args = parser.parse_args() + args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {}) - account = db.session.query(Account).filter_by(email=args["owner_email"]).first() + account = db.session.query(Account).filter_by(email=args.owner_email).first() if account is None: return {"message": "owner account not found."}, 404 - tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) + tenant = TenantService.create_tenant(args.name, is_from_dashboard=True) TenantService.create_tenant_member(tenant, account, role="owner") tenant_was_created.send(tenant) @@ -62,6 +72,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource): @enterprise_inner_api_only @inner_api_ns.doc("create_enterprise_workspace_ownerless") @inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment") + @inner_api_ns.expect(inner_api_ns.models[WorkspaceOwnerlessPayload.__name__]) @inner_api_ns.doc( responses={ 200: "Workspace created successfully", @@ -70,10 +81,9 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource): } ) def post(self): - parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json") - args = parser.parse_args() + args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {}) - tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) + tenant = TenantService.create_tenant(args.name, is_from_dashboard=True) tenant_was_created.send(tenant) diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 8d8fe6b3a8..90137a10ba 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,10 +1,11 @@ -from typing import Union +from typing import Any, Union from flask import Response -from flask_restx import Resource, reqparse -from pydantic import ValidationError +from flask_restx import Resource +from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_model from controllers.console.app.mcp_server import AppMCPServerStatus from controllers.mcp import mcp_ns from core.app.app_config.entities import VariableEntity @@ -24,27 +25,19 @@ class MCPRequestError(Exception): super().__init__(message) -def int_or_str(value): - """Validate that a value is either an integer or string.""" - if isinstance(value, (int, str)): - return value - else: - return None +class MCPRequestPayload(BaseModel): + jsonrpc: str = Field(description="JSON-RPC version (should be '2.0')") + method: str = Field(description="The method to invoke") + params: dict[str, Any] | None = Field(default=None, description="Parameters for the method") + id: int | str | None = Field(default=None, description="Request ID for tracking responses") -# Define parser for both documentation and validation -mcp_request_parser = ( - reqparse.RequestParser() - .add_argument("jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')") - .add_argument("method", type=str, required=True, location="json", help="The method to invoke") - .add_argument("params", type=dict, required=False, location="json", help="Parameters for the method") - .add_argument("id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses") -) +register_schema_model(mcp_ns, MCPRequestPayload) @mcp_ns.route("/server//mcp") class MCPAppApi(Resource): - @mcp_ns.expect(mcp_request_parser) + @mcp_ns.expect(mcp_ns.models[MCPRequestPayload.__name__]) @mcp_ns.doc("handle_mcp_request") @mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server") @mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"}) @@ -70,9 +63,9 @@ class MCPAppApi(Resource): Raises: ValidationError: Invalid request format or parameters """ - args = mcp_request_parser.parse_args() - request_id: Union[int, str] | None = args.get("id") - mcp_request = self._parse_mcp_request(args) + args = MCPRequestPayload.model_validate(mcp_ns.payload or {}) + request_id: Union[int, str] | None = args.id + mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True)) with Session(db.engine, expire_on_commit=False) as session: # Get MCP server and app diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index f26718555a..63c373b50f 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,9 +1,11 @@ from typing import Literal from flask import request -from flask_restx import Api, Namespace, Resource, fields, reqparse +from flask_restx import Api, Namespace, Resource, fields from flask_restx.api import HTTPStatus +from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token @@ -12,26 +14,24 @@ from fields.annotation_fields import annotation_fields, build_annotation_model from models.model import App from services.annotation_service import AppAnnotationService -# Define parsers for annotation API -annotation_create_parser = ( - reqparse.RequestParser() - .add_argument("question", required=True, type=str, location="json", help="Annotation question") - .add_argument("answer", required=True, type=str, location="json", help="Annotation answer") -) -annotation_reply_action_parser = ( - reqparse.RequestParser() - .add_argument( - "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching" - ) - .add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name") - .add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name") -) +class AnnotationCreatePayload(BaseModel): + question: str = Field(description="Annotation question") + answer: str = Field(description="Annotation answer") + + +class AnnotationReplyActionPayload(BaseModel): + score_threshold: float = Field(description="Score threshold for annotation matching") + embedding_provider_name: str = Field(description="Embedding provider name") + embedding_model_name: str = Field(description="Embedding model name") + + +register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload) @service_api_ns.route("/apps/annotation-reply/") class AnnotationReplyActionApi(Resource): - @service_api_ns.expect(annotation_reply_action_parser) + @service_api_ns.expect(service_api_ns.models[AnnotationReplyActionPayload.__name__]) @service_api_ns.doc("annotation_reply_action") @service_api_ns.doc(description="Enable or disable annotation reply feature") @service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"}) @@ -44,7 +44,7 @@ class AnnotationReplyActionApi(Resource): @validate_app_token def post(self, app_model: App, action: Literal["enable", "disable"]): """Enable or disable annotation reply feature.""" - args = annotation_reply_action_parser.parse_args() + args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() if action == "enable": result = AppAnnotationService.enable_app_annotation(args, app_model.id) elif action == "disable": @@ -126,7 +126,7 @@ class AnnotationListApi(Resource): "page": page, } - @service_api_ns.expect(annotation_create_parser) + @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__]) @service_api_ns.doc("create_annotation") @service_api_ns.doc(description="Create a new annotation") @service_api_ns.doc( @@ -139,14 +139,14 @@ class AnnotationListApi(Resource): @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App): """Create a new annotation.""" - args = annotation_create_parser.parse_args() + args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) return annotation, 201 @service_api_ns.route("/apps/annotations/") class AnnotationUpdateDeleteApi(Resource): - @service_api_ns.expect(annotation_create_parser) + @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__]) @service_api_ns.doc("update_annotation") @service_api_ns.doc(description="Update an existing annotation") @service_api_ns.doc(params={"annotation_id": "Annotation ID"}) @@ -163,7 +163,7 @@ class AnnotationUpdateDeleteApi(Resource): @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" - args = annotation_create_parser.parse_args() + args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) return annotation diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index c069a7ddfb..e383920460 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,10 +1,12 @@ import logging from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.schema import register_schema_model from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( AppUnavailableError, @@ -84,19 +86,19 @@ class AudioApi(Resource): raise InternalServerError() -# Define parser for text-to-audio API -text_to_audio_parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=str, required=False, location="json", help="Message ID") - .add_argument("voice", type=str, location="json", help="Voice to use for TTS") - .add_argument("text", type=str, location="json", help="Text to convert to audio") - .add_argument("streaming", type=bool, location="json", help="Enable streaming response") -) +class TextToAudioPayload(BaseModel): + message_id: str | None = Field(default=None, description="Message ID") + voice: str | None = Field(default=None, description="Voice to use for TTS") + text: str | None = Field(default=None, description="Text to convert to audio") + streaming: bool | None = Field(default=None, description="Enable streaming response") + + +register_schema_model(service_api_ns, TextToAudioPayload) @service_api_ns.route("/text-to-audio") class TextApi(Resource): - @service_api_ns.expect(text_to_audio_parser) + @service_api_ns.expect(service_api_ns.models[TextToAudioPayload.__name__]) @service_api_ns.doc("text_to_audio") @service_api_ns.doc(description="Convert text to audio using text-to-speech") @service_api_ns.doc( @@ -114,11 +116,11 @@ class TextApi(Resource): Converts the provided text to audio using the specified voice. """ try: - args = text_to_audio_parser.parse_args() + payload = TextToAudioPayload.model_validate(service_api_ns.payload or {}) - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + message_id = payload.message_id + text = payload.text + voice = payload.voice response = AudioService.transcript_tts( app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id ) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index c5dd919759..b7fb01c6fe 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,10 +1,14 @@ import logging +from typing import Any, Literal +from uuid import UUID from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( AppUnavailableError, @@ -26,7 +30,6 @@ from core.errors.error import ( from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper -from libs.helper import uuid_value from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService @@ -36,40 +39,43 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) -# Define parser for completion API -completion_parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion") - .add_argument("query", type=str, location="json", default="", help="The query string") - .add_argument("files", type=list, required=False, location="json", help="List of file attachments") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode") - .add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source") -) +class CompletionRequestPayload(BaseModel): + inputs: dict[str, Any] + query: str = Field(default="") + files: list[dict[str, Any]] | None = None + response_mode: Literal["blocking", "streaming"] | None = None + retriever_from: str = Field(default="dev") -# Define parser for chat API -chat_parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat") - .add_argument("query", type=str, required=True, location="json", help="The chat query") - .add_argument("files", type=list, required=False, location="json", help="List of file attachments") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode") - .add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID") - .add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source") - .add_argument( - "auto_generate_name", - type=bool, - required=False, - default=True, - location="json", - help="Auto generate conversation name", - ) - .add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat") -) + +class ChatRequestPayload(BaseModel): + inputs: dict[str, Any] + query: str + files: list[dict[str, Any]] | None = None + response_mode: Literal["blocking", "streaming"] | None = None + conversation_id: str | None = Field(default=None, description="Conversation UUID") + retriever_from: str = Field(default="dev") + auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") + workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") + + @field_validator("conversation_id", mode="before") + @classmethod + def normalize_conversation_id(cls, value: str | UUID | None) -> str | None: + """Allow missing or blank conversation IDs; enforce UUID format when provided.""" + if not value: + return None + + try: + return helper.uuid_value(value) + except ValueError as exc: + raise ValueError("conversation_id must be a valid UUID") from exc + + +register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload) @service_api_ns.route("/completion-messages") class CompletionApi(Resource): - @service_api_ns.expect(completion_parser) + @service_api_ns.expect(service_api_ns.models[CompletionRequestPayload.__name__]) @service_api_ns.doc("create_completion") @service_api_ns.doc(description="Create a completion for the given prompt") @service_api_ns.doc( @@ -91,12 +97,13 @@ class CompletionApi(Resource): if app_model.mode != AppMode.COMPLETION: raise AppUnavailableError() - args = completion_parser.parse_args() + payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {}) external_trace_id = get_external_trace_id(request) + args = payload.model_dump(exclude_none=True) if external_trace_id: args["external_trace_id"] = external_trace_id - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" args["auto_generate_name"] = False @@ -162,7 +169,7 @@ class CompletionStopApi(Resource): @service_api_ns.route("/chat-messages") class ChatApi(Resource): - @service_api_ns.expect(chat_parser) + @service_api_ns.expect(service_api_ns.models[ChatRequestPayload.__name__]) @service_api_ns.doc("create_chat_message") @service_api_ns.doc(description="Send a message in a chat conversation") @service_api_ns.doc( @@ -186,13 +193,14 @@ class ChatApi(Resource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - args = chat_parser.parse_args() + payload = ChatRequestPayload.model_validate(service_api_ns.payload or {}) external_trace_id = get_external_trace_id(request) + args = payload.model_dump(exclude_none=True) if external_trace_id: args["external_trace_id"] = external_trace_id - streaming = args["response_mode"] == "streaming" + streaming = payload.response_mode == "streaming" try: response = AppGenerateService.generate( diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index c4e23dd2e7..be6d837032 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,10 +1,15 @@ -from flask_restx import Resource, reqparse +from typing import Any, Literal +from uuid import UUID + +from flask import request +from flask_restx import Resource from flask_restx._http import HTTPStatus -from flask_restx.inputs import int_range +from pydantic import BaseModel, Field, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound import services +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token @@ -19,74 +24,51 @@ from fields.conversation_variable_fields import ( build_conversation_variable_infinite_scroll_pagination_model, build_conversation_variable_model, ) -from libs.helper import uuid_value from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService -# Define parsers for conversation APIs -conversation_list_parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination") - .add_argument( - "limit", - type=int_range(1, 100), - required=False, - default=20, - location="args", - help="Number of conversations to return", - ) - .add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - help="Sort order for conversations", - ) -) -conversation_rename_parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=False, location="json", help="New conversation name") - .add_argument( - "auto_generate", - type=bool, - required=False, - default=False, - location="json", - help="Auto-generate conversation name", +class ConversationListQuery(BaseModel): + last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return") + sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( + default="-updated_at", description="Sort order for conversations" ) -) -conversation_variables_parser = ( - reqparse.RequestParser() - .add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination") - .add_argument( - "limit", - type=int_range(1, 100), - required=False, - default=20, - location="args", - help="Number of variables to return", - ) -) -conversation_variable_update_parser = reqparse.RequestParser().add_argument( - # using lambda is for passing the already-typed value without modification - # if no lambda, it will be converted to string - # the string cannot be converted using json.loads - "value", - required=True, - location="json", - type=lambda x: x, - help="New value for the conversation variable", +class ConversationRenamePayload(BaseModel): + name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)") + auto_generate: bool = Field(default=False, description="Auto-generate conversation name") + + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + + +class ConversationVariablesQuery(BaseModel): + last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") + + +class ConversationVariableUpdatePayload(BaseModel): + value: Any + + +register_schema_models( + service_api_ns, + ConversationListQuery, + ConversationRenamePayload, + ConversationVariablesQuery, + ConversationVariableUpdatePayload, ) @service_api_ns.route("/conversations") class ConversationApi(Resource): - @service_api_ns.expect(conversation_list_parser) + @service_api_ns.expect(service_api_ns.models[ConversationListQuery.__name__]) @service_api_ns.doc("list_conversations") @service_api_ns.doc(description="List all conversations for the current user") @service_api_ns.doc( @@ -107,7 +89,8 @@ class ConversationApi(Resource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - args = conversation_list_parser.parse_args() + query_args = ConversationListQuery.model_validate(request.args.to_dict()) + last_id = str(query_args.last_id) if query_args.last_id else None try: with Session(db.engine) as session: @@ -115,10 +98,10 @@ class ConversationApi(Resource): session=session, app_model=app_model, user=end_user, - last_id=args["last_id"], - limit=args["limit"], + last_id=last_id, + limit=query_args.limit, invoke_from=InvokeFrom.SERVICE_API, - sort_by=args["sort_by"], + sort_by=query_args.sort_by, ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -155,7 +138,7 @@ class ConversationDetailApi(Resource): @service_api_ns.route("/conversations//name") class ConversationRenameApi(Resource): - @service_api_ns.expect(conversation_rename_parser) + @service_api_ns.expect(service_api_ns.models[ConversationRenamePayload.__name__]) @service_api_ns.doc("rename_conversation") @service_api_ns.doc(description="Rename a conversation or auto-generate a name") @service_api_ns.doc(params={"c_id": "Conversation ID"}) @@ -176,17 +159,17 @@ class ConversationRenameApi(Resource): conversation_id = str(c_id) - args = conversation_rename_parser.parse_args() + payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {}) try: - return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) + return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @service_api_ns.route("/conversations//variables") class ConversationVariablesApi(Resource): - @service_api_ns.expect(conversation_variables_parser) + @service_api_ns.expect(service_api_ns.models[ConversationVariablesQuery.__name__]) @service_api_ns.doc("list_conversation_variables") @service_api_ns.doc(description="List all variables for a conversation") @service_api_ns.doc(params={"c_id": "Conversation ID"}) @@ -211,11 +194,12 @@ class ConversationVariablesApi(Resource): conversation_id = str(c_id) - args = conversation_variables_parser.parse_args() + query_args = ConversationVariablesQuery.model_validate(request.args.to_dict()) + last_id = str(query_args.last_id) if query_args.last_id else None try: return ConversationService.get_conversational_variable( - app_model, conversation_id, end_user, args["limit"], args["last_id"] + app_model, conversation_id, end_user, query_args.limit, last_id ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -223,7 +207,7 @@ class ConversationVariablesApi(Resource): @service_api_ns.route("/conversations//variables/") class ConversationVariableDetailApi(Resource): - @service_api_ns.expect(conversation_variable_update_parser) + @service_api_ns.expect(service_api_ns.models[ConversationVariableUpdatePayload.__name__]) @service_api_ns.doc("update_conversation_variable") @service_api_ns.doc(description="Update a conversation variable's value") @service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"}) @@ -250,11 +234,11 @@ class ConversationVariableDetailApi(Resource): conversation_id = str(c_id) variable_id = str(variable_id) - args = conversation_variable_update_parser.parse_args() + payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {}) try: return ConversationService.update_conversation_variable( - app_model, conversation_id, variable_id, end_user, args["value"] + app_model, conversation_id, variable_id, end_user, payload.value ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py index b8e91f0657..60f422b88e 100644 --- a/api/controllers/service_api/app/file_preview.py +++ b/api/controllers/service_api/app/file_preview.py @@ -1,9 +1,11 @@ import logging from urllib.parse import quote -from flask import Response -from flask_restx import Resource, reqparse +from flask import Response, request +from flask_restx import Resource +from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_model from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( FileAccessDeniedError, @@ -17,10 +19,11 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile logger = logging.getLogger(__name__) -# Define parser for file preview API -file_preview_parser = reqparse.RequestParser().add_argument( - "as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment" -) +class FilePreviewQuery(BaseModel): + as_attachment: bool = Field(default=False, description="Download as attachment") + + +register_schema_model(service_api_ns, FilePreviewQuery) @service_api_ns.route("/files//preview") @@ -32,7 +35,7 @@ class FilePreviewApi(Resource): Files can only be accessed if they belong to messages within the requesting app's context. """ - @service_api_ns.expect(file_preview_parser) + @service_api_ns.expect(service_api_ns.models[FilePreviewQuery.__name__]) @service_api_ns.doc("preview_file") @service_api_ns.doc(description="Preview or download a file uploaded via Service API") @service_api_ns.doc(params={"file_id": "UUID of the file to preview"}) @@ -55,7 +58,7 @@ class FilePreviewApi(Resource): file_id = str(file_id) # Parse query parameters - args = file_preview_parser.parse_args() + args = FilePreviewQuery.model_validate(request.args.to_dict()) # Validate file ownership and get file objects _, upload_file = self._validate_file_ownership(file_id, app_model.id) @@ -67,7 +70,7 @@ class FilePreviewApi(Resource): raise FileNotFoundError(f"Failed to load file content: {str(e)}") # Build response with appropriate headers - response = self._build_file_response(generator, upload_file, args["as_attachment"]) + response = self._build_file_response(generator, upload_file, args.as_attachment) return response diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index b8e5ed28e4..d342f4e661 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,11 +1,15 @@ import json import logging +from typing import Literal +from uuid import UUID -from flask_restx import Api, Namespace, Resource, fields, reqparse -from flask_restx.inputs import int_range +from flask import request +from flask_restx import Namespace, Resource, fields +from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token @@ -13,7 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import build_message_file_model from fields.message_fields import build_agent_thought_model, build_feedback_model from fields.raws import FilesContainedField -from libs.helper import TimestampField, uuid_value +from libs.helper import TimestampField from models.model import App, AppMode, EndUser from services.errors.message import ( FirstMessageNotExistsError, @@ -25,42 +29,26 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) -# Define parsers for message APIs -message_list_parser = ( - reqparse.RequestParser() - .add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID") - .add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination") - .add_argument( - "limit", - type=int_range(1, 100), - required=False, - default=20, - location="args", - help="Number of messages to return", - ) -) - -message_feedback_parser = ( - reqparse.RequestParser() - .add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating") - .add_argument("content", type=str, location="json", help="Feedback content") -) - -feedback_list_parser = ( - reqparse.RequestParser() - .add_argument("page", type=int, default=1, location="args", help="Page number") - .add_argument( - "limit", - type=int_range(1, 101), - required=False, - default=20, - location="args", - help="Number of feedbacks per page", - ) -) +class MessageListQuery(BaseModel): + conversation_id: UUID + first_id: UUID | None = None + limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return") -def build_message_model(api_or_ns: Api | Namespace): +class MessageFeedbackPayload(BaseModel): + rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") + content: str | None = Field(default=None, description="Feedback content") + + +class FeedbackListQuery(BaseModel): + page: int = Field(default=1, ge=1, description="Page number") + limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page") + + +register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery) + + +def build_message_model(api_or_ns: Namespace): """Build the message model for the API or Namespace.""" # First build the nested models feedback_model = build_feedback_model(api_or_ns) @@ -90,7 +78,7 @@ def build_message_model(api_or_ns: Api | Namespace): return api_or_ns.model("Message", message_fields) -def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): +def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace): """Build the message infinite scroll pagination model for the API or Namespace.""" # Build the nested message model first message_model = build_message_model(api_or_ns) @@ -105,7 +93,7 @@ def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): @service_api_ns.route("/messages") class MessageListApi(Resource): - @service_api_ns.expect(message_list_parser) + @service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__]) @service_api_ns.doc("list_messages") @service_api_ns.doc(description="List messages in a conversation") @service_api_ns.doc( @@ -126,11 +114,13 @@ class MessageListApi(Resource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - args = message_list_parser.parse_args() + query_args = MessageListQuery.model_validate(request.args.to_dict()) + conversation_id = str(query_args.conversation_id) + first_id = str(query_args.first_id) if query_args.first_id else None try: return MessageService.pagination_by_first_id( - app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + app_model, end_user, conversation_id, first_id, query_args.limit ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -140,7 +130,7 @@ class MessageListApi(Resource): @service_api_ns.route("/messages//feedbacks") class MessageFeedbackApi(Resource): - @service_api_ns.expect(message_feedback_parser) + @service_api_ns.expect(service_api_ns.models[MessageFeedbackPayload.__name__]) @service_api_ns.doc("create_message_feedback") @service_api_ns.doc(description="Submit feedback for a message") @service_api_ns.doc(params={"message_id": "Message ID"}) @@ -159,15 +149,15 @@ class MessageFeedbackApi(Resource): """ message_id = str(message_id) - args = message_feedback_parser.parse_args() + payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {}) try: MessageService.create_feedback( app_model=app_model, message_id=message_id, user=end_user, - rating=args.get("rating"), - content=args.get("content"), + rating=payload.rating, + content=payload.content, ) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -177,7 +167,7 @@ class MessageFeedbackApi(Resource): @service_api_ns.route("/app/feedbacks") class AppGetFeedbacksApi(Resource): - @service_api_ns.expect(feedback_list_parser) + @service_api_ns.expect(service_api_ns.models[FeedbackListQuery.__name__]) @service_api_ns.doc("get_app_feedbacks") @service_api_ns.doc(description="Get all feedbacks for the application") @service_api_ns.doc( @@ -192,8 +182,8 @@ class AppGetFeedbacksApi(Resource): Returns paginated list of all feedback submitted for messages in this app. """ - args = feedback_list_parser.parse_args() - feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"]) + query_args = FeedbackListQuery.model_validate(request.args.to_dict()) + feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=query_args.page, limit=query_args.limit) return {"data": feedbacks} diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index af5eae463d..4964888fd6 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,12 +1,14 @@ import logging +from typing import Any, Literal from dateutil.parser import isoparse from flask import request -from flask_restx import Api, Namespace, Resource, fields, reqparse -from flask_restx.inputs import int_range +from flask_restx import Api, Namespace, Resource, fields +from pydantic import BaseModel, Field from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( CompletionRequestError, @@ -41,37 +43,25 @@ from services.workflow_app_service import WorkflowAppService logger = logging.getLogger(__name__) -# Define parsers for workflow APIs -workflow_run_parser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("files", type=list, required=False, location="json") - .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") -) -workflow_log_parser = ( - reqparse.RequestParser() - .add_argument("keyword", type=str, location="args") - .add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") - .add_argument("created_at__before", type=str, location="args") - .add_argument("created_at__after", type=str, location="args") - .add_argument( - "created_by_end_user_session_id", - type=str, - location="args", - required=False, - default=None, - ) - .add_argument( - "created_by_account", - type=str, - location="args", - required=False, - default=None, - ) - .add_argument("page", type=int_range(1, 99999), default=1, location="args") - .add_argument("limit", type=int_range(1, 100), default=20, location="args") -) +class WorkflowRunPayload(BaseModel): + inputs: dict[str, Any] + files: list[dict[str, Any]] | None = None + response_mode: Literal["blocking", "streaming"] | None = None + + +class WorkflowLogQuery(BaseModel): + keyword: str | None = None + status: Literal["succeeded", "failed", "stopped"] | None = None + created_at__before: str | None = None + created_at__after: str | None = None + created_by_end_user_session_id: str | None = None + created_by_account: str | None = None + page: int = Field(default=1, ge=1, le=99999) + limit: int = Field(default=20, ge=1, le=100) + + +register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery) workflow_run_fields = { "id": fields.String, @@ -130,7 +120,7 @@ class WorkflowRunDetailApi(Resource): @service_api_ns.route("/workflows/run") class WorkflowRunApi(Resource): - @service_api_ns.expect(workflow_run_parser) + @service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__]) @service_api_ns.doc("run_workflow") @service_api_ns.doc(description="Execute a workflow") @service_api_ns.doc( @@ -154,11 +144,12 @@ class WorkflowRunApi(Resource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - args = workflow_run_parser.parse_args() + payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {}) + args = payload.model_dump(exclude_none=True) external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id - streaming = args.get("response_mode") == "streaming" + streaming = payload.response_mode == "streaming" try: response = AppGenerateService.generate( @@ -185,7 +176,7 @@ class WorkflowRunApi(Resource): @service_api_ns.route("/workflows//run") class WorkflowRunByIdApi(Resource): - @service_api_ns.expect(workflow_run_parser) + @service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__]) @service_api_ns.doc("run_workflow_by_id") @service_api_ns.doc(description="Execute a specific workflow by ID") @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"}) @@ -209,7 +200,8 @@ class WorkflowRunByIdApi(Resource): if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - args = workflow_run_parser.parse_args() + payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {}) + args = payload.model_dump(exclude_none=True) # Add workflow_id to args for AppGenerateService args["workflow_id"] = workflow_id @@ -217,7 +209,7 @@ class WorkflowRunByIdApi(Resource): external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id - streaming = args.get("response_mode") == "streaming" + streaming = payload.response_mode == "streaming" try: response = AppGenerateService.generate( @@ -279,7 +271,7 @@ class WorkflowTaskStopApi(Resource): @service_api_ns.route("/workflows/logs") class WorkflowAppLogApi(Resource): - @service_api_ns.expect(workflow_log_parser) + @service_api_ns.expect(service_api_ns.models[WorkflowLogQuery.__name__]) @service_api_ns.doc("get_workflow_logs") @service_api_ns.doc(description="Get workflow execution logs") @service_api_ns.doc( @@ -295,14 +287,11 @@ class WorkflowAppLogApi(Resource): Returns paginated workflow execution logs with filtering options. """ - args = workflow_log_parser.parse_args() + args = WorkflowLogQuery.model_validate(request.args.to_dict()) - args.status = WorkflowExecutionStatus(args.status) if args.status else None - if args.created_at__before: - args.created_at__before = isoparse(args.created_at__before) - - if args.created_at__after: - args.created_at__after = isoparse(args.created_at__after) + status = WorkflowExecutionStatus(args.status) if args.status else None + created_at_before = isoparse(args.created_at__before) if args.created_at__before else None + created_at_after = isoparse(args.created_at__after) if args.created_at__after else None # get paginate workflow app logs workflow_app_service = WorkflowAppService() @@ -311,9 +300,9 @@ class WorkflowAppLogApi(Resource): session=session, app_model=app_model, keyword=args.keyword, - status=args.status, - created_at_before=args.created_at__before, - created_at_after=args.created_at__after, + status=status, + created_at_before=created_at_before, + created_at_after=created_at_after, page=args.page, limit=args.limit, created_by_end_user_session_id=args.created_by_end_user_session_id, diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 4cca3e6ce8..7692aeed23 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,10 +1,12 @@ from typing import Any, Literal, cast from flask import request -from flask_restx import marshal, reqparse +from flask_restx import marshal +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import Forbidden, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError @@ -18,173 +20,83 @@ from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import build_dataset_tag_fields from libs.login import current_user -from libs.validators import validate_description_length from models.account import Account -from models.dataset import Dataset, DatasetPermissionEnum +from models.dataset import DatasetPermissionEnum from models.provider_ids import ModelProviderID from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.tag_service import TagService -def _validate_name(name): - if not name or len(name) < 1 or len(name) > 40: - raise ValueError("Name must be between 1 to 40 characters.") - return name +class DatasetCreatePayload(BaseModel): + name: str = Field(..., min_length=1, max_length=40) + description: str = Field(default="", description="Dataset description (max 400 chars)", max_length=400) + indexing_technique: Literal["high_quality", "economy"] | None = None + permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME + external_knowledge_api_id: str | None = None + provider: str = "vendor" + external_knowledge_id: str | None = None + retrieval_model: RetrievalModel | None = None + embedding_model: str | None = None + embedding_model_provider: str | None = None -# Define parsers for dataset operations -dataset_create_parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, - ) - .add_argument( - "description", - type=validate_description_length, - nullable=True, - required=False, - default="", - ) - .add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - help="Invalid indexing technique.", - ) - .add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", - required=False, - nullable=False, - ) - .add_argument( - "external_knowledge_api_id", - type=str, - nullable=True, - required=False, - default="_validate_name", - ) - .add_argument( - "provider", - type=str, - nullable=True, - required=False, - default="vendor", - ) - .add_argument( - "external_knowledge_id", - type=str, - nullable=True, - required=False, - ) - .add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") - .add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") -) +class DatasetUpdatePayload(BaseModel): + name: str | None = Field(default=None, min_length=1, max_length=40) + description: str | None = Field(default=None, description="Dataset description (max 400 chars)", max_length=400) + indexing_technique: Literal["high_quality", "economy"] | None = None + permission: DatasetPermissionEnum | None = None + embedding_model: str | None = None + embedding_model_provider: str | None = None + retrieval_model: RetrievalModel | None = None + partial_member_list: list[str] | None = None + external_retrieval_model: dict[str, Any] | None = None + external_knowledge_id: str | None = None + external_knowledge_api_id: str | None = None -dataset_update_parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, - ) - .add_argument("description", location="json", store_missing=False, type=validate_description_length) - .add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help="Invalid indexing technique.", - ) - .add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", - ) - .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") - .add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.") - .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") - .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") - .add_argument( - "external_retrieval_model", - type=dict, - required=False, - nullable=True, - location="json", - help="Invalid external retrieval model.", - ) - .add_argument( - "external_knowledge_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge id.", - ) - .add_argument( - "external_knowledge_api_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge api id.", - ) -) -tag_create_parser = reqparse.RequestParser().add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=lambda x: x - if x and 1 <= len(x) <= 50 - else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), -) +class TagNamePayload(BaseModel): + name: str = Field(..., min_length=1, max_length=50) -tag_update_parser = ( - reqparse.RequestParser() - .add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=lambda x: x - if x and 1 <= len(x) <= 50 - else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), - ) - .add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) -) -tag_delete_parser = reqparse.RequestParser().add_argument( - "tag_id", nullable=False, required=True, help="Id of a tag.", type=str -) +class TagCreatePayload(TagNamePayload): + pass -tag_binding_parser = ( - reqparse.RequestParser() - .add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.") - .add_argument( - "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." - ) -) -tag_unbinding_parser = ( - reqparse.RequestParser() - .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") +class TagUpdatePayload(TagNamePayload): + tag_id: str + + +class TagDeletePayload(BaseModel): + tag_id: str + + +class TagBindingPayload(BaseModel): + tag_ids: list[str] + target_id: str + + @field_validator("tag_ids") + @classmethod + def validate_tag_ids(cls, value: list[str]) -> list[str]: + if not value: + raise ValueError("Tag IDs is required.") + return value + + +class TagUnbindingPayload(BaseModel): + tag_id: str + target_id: str + + +register_schema_models( + service_api_ns, + DatasetCreatePayload, + DatasetUpdatePayload, + TagCreatePayload, + TagUpdatePayload, + TagDeletePayload, + TagBindingPayload, + TagUnbindingPayload, ) @@ -239,7 +151,7 @@ class DatasetListApi(DatasetApiResource): response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 - @service_api_ns.expect(dataset_create_parser) + @service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__]) @service_api_ns.doc("create_dataset") @service_api_ns.doc(description="Create a new dataset") @service_api_ns.doc( @@ -252,42 +164,41 @@ class DatasetListApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id): """Resource for creating datasets.""" - args = dataset_create_parser.parse_args() + payload = DatasetCreatePayload.model_validate(service_api_ns.payload or {}) - embedding_model_provider = args.get("embedding_model_provider") - embedding_model = args.get("embedding_model") + embedding_model_provider = payload.embedding_model_provider + embedding_model = payload.embedding_model if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) - retrieval_model = args.get("retrieval_model") + retrieval_model = payload.retrieval_model if ( retrieval_model - and retrieval_model.get("reranking_model") - and retrieval_model.get("reranking_model").get("reranking_provider_name") + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name ): DatasetService.check_reranking_model_setting( tenant_id, - retrieval_model.get("reranking_model").get("reranking_provider_name"), - retrieval_model.get("reranking_model").get("reranking_model_name"), + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, ) try: assert isinstance(current_user, Account) dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, - name=args["name"], - description=args["description"], - indexing_technique=args["indexing_technique"], + name=payload.name, + description=payload.description, + indexing_technique=payload.indexing_technique, account=current_user, - permission=args["permission"], - provider=args["provider"], - external_knowledge_api_id=args["external_knowledge_api_id"], - external_knowledge_id=args["external_knowledge_id"], - embedding_model_provider=args["embedding_model_provider"], - embedding_model_name=args["embedding_model"], - retrieval_model=RetrievalModel.model_validate(args["retrieval_model"]) - if args["retrieval_model"] is not None - else None, + permission=str(payload.permission) if payload.permission else None, + provider=payload.provider, + external_knowledge_api_id=payload.external_knowledge_api_id, + external_knowledge_id=payload.external_knowledge_id, + embedding_model_provider=payload.embedding_model_provider, + embedding_model_name=payload.embedding_model, + retrieval_model=payload.retrieval_model, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -353,7 +264,7 @@ class DatasetApi(DatasetApiResource): return data, 200 - @service_api_ns.expect(dataset_update_parser) + @service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__]) @service_api_ns.doc("update_dataset") @service_api_ns.doc(description="Update an existing dataset") @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @@ -372,36 +283,45 @@ class DatasetApi(DatasetApiResource): if dataset is None: raise NotFound("Dataset not found.") - args = dataset_update_parser.parse_args() - data = request.get_json() + payload_dict = service_api_ns.payload or {} + payload = DatasetUpdatePayload.model_validate(payload_dict) + update_data = payload.model_dump(exclude_unset=True) + if payload.permission is not None: + update_data["permission"] = str(payload.permission) + if payload.retrieval_model is not None: + update_data["retrieval_model"] = payload.retrieval_model.model_dump() # check embedding model setting - embedding_model_provider = data.get("embedding_model_provider") - embedding_model = data.get("embedding_model") - if data.get("indexing_technique") == "high_quality" or embedding_model_provider: + embedding_model_provider = payload.embedding_model_provider + embedding_model = payload.embedding_model + if payload.indexing_technique == "high_quality" or embedding_model_provider: if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting( dataset.tenant_id, embedding_model_provider, embedding_model ) - retrieval_model = data.get("retrieval_model") + retrieval_model = payload.retrieval_model if ( retrieval_model - and retrieval_model.get("reranking_model") - and retrieval_model.get("reranking_model").get("reranking_provider_name") + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name ): DatasetService.check_reranking_model_setting( dataset.tenant_id, - retrieval_model.get("reranking_model").get("reranking_provider_name"), - retrieval_model.get("reranking_model").get("reranking_model_name"), + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, ) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( - current_user, dataset, data.get("permission"), data.get("partial_member_list") + current_user, + dataset, + str(payload.permission) if payload.permission else None, + payload.partial_member_list, ) - dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) + dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user) if dataset is None: raise NotFound("Dataset not found.") @@ -410,15 +330,10 @@ class DatasetApi(DatasetApiResource): assert isinstance(current_user, Account) tenant_id = current_user.current_tenant_id - if data.get("partial_member_list") and data.get("permission") == "partial_members": - DatasetPermissionService.update_partial_member_list( - tenant_id, dataset_id_str, data.get("partial_member_list") - ) + if payload.partial_member_list and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM: + DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list) # clear partial member list when permission is only_me or all_team_members - elif ( - data.get("permission") == DatasetPermissionEnum.ONLY_ME - or data.get("permission") == DatasetPermissionEnum.ALL_TEAM - ): + elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}: DatasetPermissionService.clear_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) @@ -556,7 +471,7 @@ class DatasetTagsApi(DatasetApiResource): return tags, 200 - @service_api_ns.expect(tag_create_parser) + @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__]) @service_api_ns.doc("create_dataset_tag") @service_api_ns.doc(description="Add a knowledge type tag") @service_api_ns.doc( @@ -574,14 +489,13 @@ class DatasetTagsApi(DatasetApiResource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = tag_create_parser.parse_args() - args["type"] = "knowledge" - tag = TagService.save_tags(args) + payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) + tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 - @service_api_ns.expect(tag_update_parser) + @service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__]) @service_api_ns.doc("update_dataset_tag") @service_api_ns.doc(description="Update a knowledge type tag") @service_api_ns.doc( @@ -598,10 +512,10 @@ class DatasetTagsApi(DatasetApiResource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = tag_update_parser.parse_args() - args["type"] = "knowledge" - tag_id = args["tag_id"] - tag = TagService.update_tags(args, tag_id) + payload = TagUpdatePayload.model_validate(service_api_ns.payload or {}) + params = {"name": payload.name, "type": "knowledge"} + tag_id = payload.tag_id + tag = TagService.update_tags(params, tag_id) binding_count = TagService.get_tag_binding_count(tag_id) @@ -609,7 +523,7 @@ class DatasetTagsApi(DatasetApiResource): return response, 200 - @service_api_ns.expect(tag_delete_parser) + @service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__]) @service_api_ns.doc("delete_dataset_tag") @service_api_ns.doc(description="Delete a knowledge type tag") @service_api_ns.doc( @@ -623,15 +537,15 @@ class DatasetTagsApi(DatasetApiResource): @edit_permission_required def delete(self, _, dataset_id): """Delete a knowledge type tag.""" - args = tag_delete_parser.parse_args() - TagService.delete_tag(args["tag_id"]) + payload = TagDeletePayload.model_validate(service_api_ns.payload or {}) + TagService.delete_tag(payload.tag_id) return 204 @service_api_ns.route("/datasets/tags/binding") class DatasetTagBindingApi(DatasetApiResource): - @service_api_ns.expect(tag_binding_parser) + @service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__]) @service_api_ns.doc("bind_dataset_tags") @service_api_ns.doc(description="Bind tags to a dataset") @service_api_ns.doc( @@ -648,16 +562,15 @@ class DatasetTagBindingApi(DatasetApiResource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = tag_binding_parser.parse_args() - args["type"] = "knowledge" - TagService.save_tag_binding(args) + payload = TagBindingPayload.model_validate(service_api_ns.payload or {}) + TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"}) return 204 @service_api_ns.route("/datasets/tags/unbinding") class DatasetTagUnbindingApi(DatasetApiResource): - @service_api_ns.expect(tag_unbinding_parser) + @service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__]) @service_api_ns.doc("unbind_dataset_tag") @service_api_ns.doc(description="Unbind a tag from a dataset") @service_api_ns.doc( @@ -674,9 +587,8 @@ class DatasetTagUnbindingApi(DatasetApiResource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - args = tag_unbinding_parser.parse_args() - args["type"] = "knowledge" - TagService.delete_tag_binding(args) + payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {}) + TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"}) return 204 diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ed47e706b6..c800c0e4e1 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -3,8 +3,8 @@ from typing import Self from uuid import UUID from flask import request -from flask_restx import marshal, reqparse -from pydantic import BaseModel, model_validator +from flask_restx import marshal +from pydantic import BaseModel, Field, model_validator from sqlalchemy import desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -37,22 +37,19 @@ from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.file_service import FileService -# Define parsers for document operations -document_text_create_parser = ( - reqparse.RequestParser() - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("text", type=str, required=True, nullable=False, location="json") - .add_argument("process_rule", type=dict, required=False, nullable=True, location="json") - .add_argument("original_document_id", type=str, required=False, location="json") - .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json") - .add_argument( - "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" - ) - .add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") - .add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") -) + +class DocumentTextCreatePayload(BaseModel): + name: str + text: str + process_rule: ProcessRule | None = None + original_document_id: str | None = None + doc_form: str = Field(default="text_model") + doc_language: str = Field(default="English") + indexing_technique: str | None = None + retrieval_model: RetrievalModel | None = None + embedding_model: str | None = None + embedding_model_provider: str | None = None + DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -72,7 +69,7 @@ class DocumentTextUpdate(BaseModel): return self -for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]: +for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]: service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore @@ -83,7 +80,7 @@ for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]: class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" - @service_api_ns.expect(document_text_create_parser) + @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__]) @service_api_ns.doc("create_document_by_text") @service_api_ns.doc(description="Create a new document by providing text content") @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @@ -99,7 +96,8 @@ class DocumentAddByTextApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" - args = document_text_create_parser.parse_args() + payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {}) + args = payload.model_dump(exclude_none=True) dataset_id = str(dataset_id) tenant_id = str(tenant_id) @@ -111,33 +109,29 @@ class DocumentAddByTextApi(DatasetApiResource): if not dataset.indexing_technique and not args["indexing_technique"]: raise ValueError("indexing_technique is required.") - text = args.get("text") - name = args.get("name") - if text is None or name is None: - raise ValueError("Both 'text' and 'name' must be non-null values.") - - embedding_model_provider = args.get("embedding_model_provider") - embedding_model = args.get("embedding_model") + embedding_model_provider = payload.embedding_model_provider + embedding_model = payload.embedding_model if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) - retrieval_model = args.get("retrieval_model") + retrieval_model = payload.retrieval_model if ( retrieval_model - and retrieval_model.get("reranking_model") - and retrieval_model.get("reranking_model").get("reranking_provider_name") + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name ): DatasetService.check_reranking_model_setting( tenant_id, - retrieval_model.get("reranking_model").get("reranking_provider_name"), - retrieval_model.get("reranking_model").get("reranking_model_name"), + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, ) if not current_user: raise ValueError("current_user is required") upload_file = FileService(db.engine).upload_text( - text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id + text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id ) data_source = { "type": "upload_file", @@ -174,7 +168,7 @@ class DocumentAddByTextApi(DatasetApiResource): class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" - @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True) + @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__]) @service_api_ns.doc("update_document_by_text") @service_api_ns.doc(description="Update an existing document by providing text content") @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @@ -189,22 +183,23 @@ class DocumentUpdateByTextApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): """Update document by text.""" - args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True) + payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {}) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() - + args = payload.model_dump(exclude_none=True) if not dataset: raise ValueError("Dataset does not exist.") - retrieval_model = args.get("retrieval_model") + retrieval_model = payload.retrieval_model if ( retrieval_model - and retrieval_model.get("reranking_model") - and retrieval_model.get("reranking_model").get("reranking_provider_name") + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name ): DatasetService.check_reranking_model_setting( tenant_id, - retrieval_model.get("reranking_model").get("reranking_provider_name"), - retrieval_model.get("reranking_model").get("reranking_model_name"), + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, ) # indexing_technique is already set in dataset since this is an update diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index f646f1f4fa..aab25c1af3 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -1,9 +1,11 @@ from typing import Literal from flask_login import current_user -from flask_restx import marshal, reqparse +from flask_restx import marshal +from pydantic import BaseModel from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_model, register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check from fields.dataset_fields import dataset_metadata_fields @@ -14,25 +16,18 @@ from services.entities.knowledge_entities.knowledge_entities import ( ) from services.metadata_service import MetadataService -# Define parsers for metadata APIs -metadata_create_parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=False, location="json", help="Metadata type") - .add_argument("name", type=str, required=True, nullable=False, location="json", help="Metadata name") -) -metadata_update_parser = reqparse.RequestParser().add_argument( - "name", type=str, required=True, nullable=False, location="json", help="New metadata name" -) +class MetadataUpdatePayload(BaseModel): + name: str -document_metadata_parser = reqparse.RequestParser().add_argument( - "operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data" -) + +register_schema_model(service_api_ns, MetadataUpdatePayload) +register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData) @service_api_ns.route("/datasets//metadata") class DatasetMetadataCreateServiceApi(DatasetApiResource): - @service_api_ns.expect(metadata_create_parser) + @service_api_ns.expect(service_api_ns.models[MetadataArgs.__name__]) @service_api_ns.doc("create_dataset_metadata") @service_api_ns.doc(description="Create metadata for a dataset") @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @@ -46,8 +41,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create metadata for a dataset.""" - args = metadata_create_parser.parse_args() - metadata_args = MetadataArgs.model_validate(args) + metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {}) dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -79,7 +73,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): @service_api_ns.route("/datasets//metadata/") class DatasetMetadataServiceApi(DatasetApiResource): - @service_api_ns.expect(metadata_update_parser) + @service_api_ns.expect(service_api_ns.models[MetadataUpdatePayload.__name__]) @service_api_ns.doc("update_dataset_metadata") @service_api_ns.doc(description="Update metadata name") @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) @@ -93,7 +87,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id, dataset_id, metadata_id): """Update metadata name.""" - args = metadata_update_parser.parse_args() + payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {}) dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -102,7 +96,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"]) + metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name) return marshal(metadata, dataset_metadata_fields), 200 @service_api_ns.doc("delete_dataset_metadata") @@ -175,7 +169,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): @service_api_ns.route("/datasets//documents/metadata") class DocumentMetadataEditServiceApi(DatasetApiResource): - @service_api_ns.expect(document_metadata_parser) + @service_api_ns.expect(service_api_ns.models[MetadataOperationData.__name__]) @service_api_ns.doc("update_documents_metadata") @service_api_ns.doc(description="Update metadata for multiple documents") @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) @@ -195,8 +189,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - args = document_metadata_parser.parse_args() - metadata_args = MetadataOperationData.model_validate(args) + metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {}) MetadataService.update_documents_metadata(dataset, metadata_args) diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index c177e9180a..0a2017e2bd 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -4,12 +4,12 @@ from collections.abc import Generator from typing import Any from flask import request -from flask_restx import reqparse -from flask_restx.reqparse import ParseResult, RequestParser +from pydantic import BaseModel from werkzeug.exceptions import Forbidden import services from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError +from controllers.common.schema import register_schema_model from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import PipelineRunError from controllers.service_api.wraps import DatasetApiResource @@ -22,11 +22,25 @@ from models.dataset import Pipeline from models.engine import db from services.errors.file import FileTooLargeError, UnsupportedFileTypeError from services.file_service import FileService -from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity +from services.rag_pipeline.entity.pipeline_service_api_entities import ( + DatasourceNodeRunApiEntity, + PipelineRunApiEntity, +) from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.rag_pipeline import RagPipelineService +class DatasourceNodeRunPayload(BaseModel): + inputs: dict[str, Any] + datasource_type: str + credential_id: str | None = None + is_published: bool + + +register_schema_model(service_api_ns, DatasourceNodeRunPayload) +register_schema_model(service_api_ns, PipelineRunApiEntity) + + @service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins") class DatasourcePluginsApi(DatasetApiResource): """Resource for datasource plugins.""" @@ -88,22 +102,20 @@ class DatasourceNodeRunApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__]) def post(self, tenant_id: str, dataset_id: str, node_id: str): """Resource for getting datasource plugins.""" - # Get query parameter to determine published or draft - parser: RequestParser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("credential_id", type=str, required=False, location="json") - .add_argument("is_published", type=bool, required=True, location="json") - ) - args: ParseResult = parser.parse_args() - - datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args) + payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {}) assert isinstance(current_user, Account) rag_pipeline_service: RagPipelineService = RagPipelineService() pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) + datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate( + { + **payload.model_dump(exclude_none=True), + "pipeline_id": str(pipeline.id), + "node_id": node_id, + } + ) return helper.compact_generate_response( PipelineGenerator.convert_to_event_stream( rag_pipeline_service.run_datasource_workflow_node( @@ -147,25 +159,10 @@ class PipelineRunApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__]) def post(self, tenant_id: str, dataset_id: str): """Resource for running a rag pipeline.""" - parser: RequestParser = ( - reqparse.RequestParser() - .add_argument("inputs", type=dict, required=True, nullable=False, location="json") - .add_argument("datasource_type", type=str, required=True, location="json") - .add_argument("datasource_info_list", type=list, required=True, location="json") - .add_argument("start_node_id", type=str, required=True, location="json") - .add_argument("is_published", type=bool, required=True, default=True, location="json") - .add_argument( - "response_mode", - type=str, - required=True, - choices=["streaming", "blocking"], - default="blocking", - location="json", - ) - ) - args: ParseResult = parser.parse_args() + payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {}) if not isinstance(current_user, Account): raise Forbidden() @@ -176,9 +173,9 @@ class PipelineRunApi(DatasetApiResource): response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate( pipeline=pipeline, user=current_user, - args=args, - invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER, - streaming=args.get("response_mode") == "streaming", + args=payload.model_dump(), + invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER, + streaming=payload.response_mode == "streaming", ) return helper.compact_generate_response(response) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 9ca500b044..b242fd2c3e 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,8 +1,12 @@ +from typing import Any + from flask import request -from flask_restx import marshal, reqparse +from flask_restx import marshal +from pydantic import BaseModel, Field from werkzeug.exceptions import NotFound from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.wraps import ( @@ -24,34 +28,42 @@ from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexing from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError -# Define parsers for segment operations -segment_create_parser = reqparse.RequestParser().add_argument( - "segments", type=list, required=False, nullable=True, location="json" -) -segment_list_parser = ( - reqparse.RequestParser() - .add_argument("status", type=str, action="append", default=[], location="args") - .add_argument("keyword", type=str, default=None, location="args") -) +class SegmentCreatePayload(BaseModel): + segments: list[dict[str, Any]] | None = None -segment_update_parser = reqparse.RequestParser().add_argument( - "segment", type=dict, required=False, nullable=True, location="json" -) -child_chunk_create_parser = reqparse.RequestParser().add_argument( - "content", type=str, required=True, nullable=False, location="json" -) +class SegmentListQuery(BaseModel): + status: list[str] = Field(default_factory=list) + keyword: str | None = None -child_chunk_list_parser = ( - reqparse.RequestParser() - .add_argument("limit", type=int, default=20, location="args") - .add_argument("keyword", type=str, default=None, location="args") - .add_argument("page", type=int, default=1, location="args") -) -child_chunk_update_parser = reqparse.RequestParser().add_argument( - "content", type=str, required=True, nullable=False, location="json" +class SegmentUpdatePayload(BaseModel): + segment: SegmentUpdateArgs + + +class ChildChunkCreatePayload(BaseModel): + content: str + + +class ChildChunkListQuery(BaseModel): + limit: int = Field(default=20, ge=1) + keyword: str | None = None + page: int = Field(default=1, ge=1) + + +class ChildChunkUpdatePayload(BaseModel): + content: str + + +register_schema_models( + service_api_ns, + SegmentCreatePayload, + SegmentListQuery, + SegmentUpdatePayload, + ChildChunkCreatePayload, + ChildChunkListQuery, + ChildChunkUpdatePayload, ) @@ -59,7 +71,7 @@ child_chunk_update_parser = reqparse.RequestParser().add_argument( class SegmentApi(DatasetApiResource): """Resource for segments.""" - @service_api_ns.expect(segment_create_parser) + @service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__]) @service_api_ns.doc("create_segments") @service_api_ns.doc(description="Create segments in a document") @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @@ -106,20 +118,20 @@ class SegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args - args = segment_create_parser.parse_args() - if args["segments"] is not None: + payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {}) + if payload.segments is not None: segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST - if segments_limit > 0 and len(args["segments"]) > segments_limit: + if segments_limit > 0 and len(payload.segments) > segments_limit: raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.") - for args_item in args["segments"]: + for args_item in payload.segments: SegmentService.segment_create_args_validate(args_item, document) - segments = SegmentService.multi_create_segment(args["segments"], document, dataset) + segments = SegmentService.multi_create_segment(payload.segments, document, dataset) return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 else: return {"error": "Segments is required"}, 400 - @service_api_ns.expect(segment_list_parser) + @service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__]) @service_api_ns.doc("list_segments") @service_api_ns.doc(description="List segments in a document") @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @@ -160,13 +172,18 @@ class SegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) - args = segment_list_parser.parse_args() + args = SegmentListQuery.model_validate( + { + "status": request.args.getlist("status"), + "keyword": request.args.get("keyword"), + } + ) segments, total = SegmentService.get_segments( document_id=document_id, tenant_id=current_tenant_id, - status_list=args["status"], - keyword=args["keyword"], + status_list=args.status, + keyword=args.keyword, page=page, limit=limit, ) @@ -217,7 +234,7 @@ class DatasetSegmentApi(DatasetApiResource): SegmentService.delete_segment(segment, document, dataset) return 204 - @service_api_ns.expect(segment_update_parser) + @service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__]) @service_api_ns.doc("update_segment") @service_api_ns.doc(description="Update a specific segment") @service_api_ns.doc( @@ -265,12 +282,9 @@ class DatasetSegmentApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") - # validate args - args = segment_update_parser.parse_args() + payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {}) - updated_segment = SegmentService.update_segment( - SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset - ) + updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset) return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 @service_api_ns.doc("get_segment") @@ -308,7 +322,7 @@ class DatasetSegmentApi(DatasetApiResource): class ChildChunkApi(DatasetApiResource): """Resource for child chunks.""" - @service_api_ns.expect(child_chunk_create_parser) + @service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__]) @service_api_ns.doc("create_child_chunk") @service_api_ns.doc(description="Create a new child chunk for a segment") @service_api_ns.doc( @@ -360,16 +374,16 @@ class ChildChunkApi(DatasetApiResource): raise ProviderNotInitializeError(ex.description) # validate args - args = child_chunk_create_parser.parse_args() + payload = ChildChunkCreatePayload.model_validate(service_api_ns.payload or {}) try: - child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset) + child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 - @service_api_ns.expect(child_chunk_list_parser) + @service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__]) @service_api_ns.doc("list_child_chunks") @service_api_ns.doc(description="List child chunks for a segment") @service_api_ns.doc( @@ -400,11 +414,17 @@ class ChildChunkApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") - args = child_chunk_list_parser.parse_args() + args = ChildChunkListQuery.model_validate( + { + "limit": request.args.get("limit", default=20, type=int), + "keyword": request.args.get("keyword"), + "page": request.args.get("page", default=1, type=int), + } + ) - page = args["page"] - limit = min(args["limit"], 100) - keyword = args["keyword"] + page = args.page + limit = min(args.limit, 100) + keyword = args.keyword child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) @@ -480,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource): return 204 - @service_api_ns.expect(child_chunk_update_parser) + @service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__]) @service_api_ns.doc("update_child_chunk") @service_api_ns.doc(description="Update a specific child chunk") @service_api_ns.doc( @@ -533,10 +553,10 @@ class DatasetChildChunkApi(DatasetApiResource): raise NotFound("Child chunk not found.") # validate args - args = child_chunk_update_parser.parse_args() + payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {}) try: - child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset) + child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index c07e18c686..24acced0d1 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,3 +1,4 @@ +import logging import time from collections.abc import Callable from datetime import timedelta @@ -28,6 +29,8 @@ P = ParamSpec("P") R = TypeVar("R") T = TypeVar("T") +logger = logging.getLogger(__name__) + class WhereisUserArg(StrEnum): """ @@ -238,8 +241,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): # Basic check: UUIDs are 36 chars with hyphens if len(str_id) == 36 and str_id.count("-") == 4: dataset_id = str_id - except: - pass + except Exception: + logger.exception("Failed to parse dataset_id from class method args") elif len(args) > 0: # Not a class method, check if args[0] looks like a UUID potential_id = args[0] @@ -247,8 +250,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): str_id = str(potential_id) if len(str_id) == 36 and str_id.count("-") == 4: dataset_id = str_id - except: - pass + except Exception: + logger.exception("Failed to parse dataset_id from positional args") # Validate dataset if dataset_id is provided if dataset_id: @@ -316,18 +319,16 @@ def validate_and_get_api_token(scope: str | None = None): ApiToken.type == scope, ) .values(last_used_at=current_time) - .returning(ApiToken) ) + stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) result = session.execute(update_stmt) - api_token = result.scalar_one_or_none() + api_token = session.scalar(stmt) + + if hasattr(result, "rowcount") and result.rowcount > 0: + session.commit() if not api_token: - stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) - api_token = session.scalar(stmt) - if not api_token: - raise Unauthorized("Access token is invalid") - else: - session.commit() + raise Unauthorized("Access token is invalid") return api_token diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 25ad6dc060..b32e35d0ca 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,4 +1,5 @@ import json +import logging from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence from typing import Any @@ -23,6 +24,8 @@ from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine from models.model import Message +logger = logging.getLogger(__name__) + class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True @@ -400,8 +403,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): action_input=json.loads(message.tool_calls[0].function.arguments), ) current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) - except: - pass + except Exception: + logger.exception("Failed to parse tool call from assistant message") elif isinstance(message, ToolPromptMessage): if current_scratchpad: assert isinstance(message.content, str) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 2aa36ddc49..93f2742599 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from enum import StrEnum, auto from typing import Any, Literal +from jsonschema import Draft7Validator, SchemaError from pydantic import BaseModel, Field, field_validator from core.file import FileTransferMethod, FileType, FileUploadConfig @@ -98,6 +99,7 @@ class VariableEntityType(StrEnum): FILE = "file" FILE_LIST = "file-list" CHECKBOX = "checkbox" + JSON_OBJECT = "json_object" class VariableEntity(BaseModel): @@ -118,6 +120,7 @@ class VariableEntity(BaseModel): allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) + json_schema: dict[str, Any] | None = Field(default=None) @field_validator("description", mode="before") @classmethod @@ -129,6 +132,17 @@ class VariableEntity(BaseModel): def convert_none_options(cls, v: Any) -> Sequence[str]: return v or [] + @field_validator("json_schema") + @classmethod + def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None: + if schema is None: + return None + try: + Draft7Validator.check_schema(schema) + except SchemaError as e: + raise ValueError(f"Invalid JSON schema: {e.message}") + return schema + class RagPipelineVariableEntity(VariableEntity): """ diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index c029e00553..ee092e55c5 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -35,6 +35,7 @@ from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from extensions.ext_redis import redis_client +from extensions.otel import WorkflowAppRunnerHandler, trace_span from models import Workflow from models.enums import UserFrom from models.model import App, Conversation, Message, MessageAnnotation @@ -80,6 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + @trace_span(WorkflowAppRunnerHandler) def run(self): app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index c98bc1ffdd..da1e9f19b6 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -62,8 +62,7 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.enums import WorkflowExecutionStatus from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory @@ -73,7 +72,7 @@ from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole -from models.workflow import Workflow, WorkflowNodeExecutionModel +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -581,7 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): with self._database_session() as session: # Save message - self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) + self._save_message(session=session, graph_runtime_state=resolved_state) yield workflow_finish_resp elif event.stopped_by in ( @@ -591,7 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): # When hitting input-moderation or annotation-reply, the workflow will not start with self._database_session() as session: # Save message - self._save_message(session=session, trace_manager=trace_manager) + self._save_message(session=session) yield self._message_end_to_stream_response() @@ -600,7 +599,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): event: QueueAdvancedChatMessageEndEvent, *, graph_runtime_state: GraphRuntimeState | None = None, - trace_manager: TraceQueueManager | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: """Handle advanced chat message end events.""" @@ -618,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): # Save message with self._database_session() as session: - self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) + self._save_message(session=session, graph_runtime_state=resolved_state) yield self._message_end_to_stream_response() @@ -770,15 +768,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): tts_publisher.publish(None) if self._conversation_name_generate_thread: - self._conversation_name_generate_thread.join() + logger.debug("Conversation name generation running as daemon thread") - def _save_message( - self, - *, - session: Session, - graph_runtime_state: GraphRuntimeState | None = None, - trace_manager: TraceQueueManager | None = None, - ): + def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None): message = self._get_message(session=session) # If there are assistant files, remove markdown image links from answer @@ -817,14 +809,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): metadata = self._task_state.metadata.model_dump() message.message_metadata = json.dumps(jsonable_encoder(metadata)) - - # Extract model provider and model_id from workflow node executions for tracing - if message.workflow_run_id: - model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id) - if model_info: - message.model_provider = model_info.get("provider") - message.model_id = model_info.get("model") - message_files = [ MessageFile( message_id=message.id, @@ -842,68 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ] session.add_all(message_files) - # Trigger MESSAGE_TRACE for tracing integrations - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id - ) - ) - - def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None: - """ - Extract model provider and model_id from workflow node executions. - Returns dict with 'provider' and 'model' keys, or None if not found. - """ - try: - # Query workflow node executions for LLM or Agent nodes - stmt = ( - select(WorkflowNodeExecutionModel) - .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) - .where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"])) - .order_by(WorkflowNodeExecutionModel.created_at.desc()) - .limit(1) - ) - node_execution = session.scalar(stmt) - - if not node_execution: - return None - - # Try to extract from execution_metadata for agent nodes - if node_execution.execution_metadata: - try: - metadata = json.loads(node_execution.execution_metadata) - agent_log = metadata.get("agent_log", []) - # Look for the first agent thought with provider info - for log_entry in agent_log: - entry_metadata = log_entry.get("metadata", {}) - provider_str = entry_metadata.get("provider") - if provider_str: - # Parse format like "langgenius/deepseek/deepseek" - parts = provider_str.split("/") - if len(parts) >= 3: - return {"provider": parts[1], "model": parts[2]} - elif len(parts) == 2: - return {"provider": parts[0], "model": parts[1]} - except (json.JSONDecodeError, KeyError, AttributeError) as e: - logger.debug("Failed to parse execution_metadata: %s", e) - - # Try to extract from process_data for llm nodes - if node_execution.process_data: - try: - process_data = json.loads(node_execution.process_data) - provider = process_data.get("model_provider") - model = process_data.get("model_name") - if provider and model: - return {"provider": provider, "model": model} - except (json.JSONDecodeError, KeyError) as e: - logger.debug("Failed to parse process_data: %s", e) - - return None - except Exception as e: - logger.warning("Failed to extract model info from workflow: %s", e) - return None - def _seed_graph_runtime_state_from_queue_manager(self) -> None: """Bootstrap the cached runtime state from the queue manager when present.""" candidate = self._base_task_pipeline.queue_manager.graph_runtime_state diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 1c6ca87925..1b0474142e 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -99,6 +99,15 @@ class BaseAppGenerator: if value is None: return None + # Treat empty placeholders for optional file inputs as unset + if ( + variable_entity.type in {VariableEntityType.FILE, VariableEntityType.FILE_LIST} + and not variable_entity.required + ): + # Treat empty string (frontend default) or empty list as unset + if not value and isinstance(value, (str, list)): + return None + if variable_entity.type in { VariableEntityType.TEXT_INPUT, VariableEntityType.SELECT, diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 9a9832dd4a..e2e6c11480 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -83,6 +83,7 @@ class AppRunner: context: str | None = None, memory: TokenBufferMemory | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: """ Organize prompt messages @@ -111,6 +112,7 @@ class AppRunner: memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) else: memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 53188cf506..f8338b226b 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.file import File from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ImagePromptMessageContent @@ -146,6 +147,7 @@ class ChatAppRunner(AppRunner): # get context from datasets context = None + context_files: list[File] = [] if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, @@ -156,7 +158,7 @@ class ChatAppRunner(AppRunner): ) dataset_retrieval = DatasetRetrieval(application_generate_entity) - context = dataset_retrieval.retrieve( + context, retrieved_files = dataset_retrieval.retrieve( app_id=app_record.id, user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, @@ -171,7 +173,11 @@ class ChatAppRunner(AppRunner): memory=memory, message_id=message.id, inputs=inputs, + vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( + "enabled", False + ), ) + context_files = retrieved_files or [] # reorganize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -186,6 +192,7 @@ class ChatAppRunner(AppRunner): context=context, memory=memory, image_detail_config=image_detail_config, + context_files=context_files, ) # check hosting moderation diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 14795a430c..38ecec5d30 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -1,3 +1,4 @@ +import logging import time from collections.abc import Mapping, Sequence from dataclasses import dataclass @@ -55,6 +56,7 @@ from models import Account, EndUser from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator NodeExecutionId = NewType("NodeExecutionId", str) +logger = logging.getLogger(__name__) @dataclass(slots=True) @@ -289,26 +291,30 @@ class WorkflowResponseConverter: ), ) - if event.node_type == NodeType.TOOL: - response.data.extras["icon"] = ToolManager.get_tool_icon( - tenant_id=self._application_generate_entity.app_config.tenant_id, - provider_type=ToolProviderType(event.provider_type), - provider_id=event.provider_id, - ) - elif event.node_type == NodeType.DATASOURCE: - manager = PluginDatasourceManager() - provider_entity = manager.fetch_datasource_provider( - self._application_generate_entity.app_config.tenant_id, - event.provider_id, - ) - response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url( - self._application_generate_entity.app_config.tenant_id - ) - elif event.node_type == NodeType.TRIGGER_PLUGIN: - response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon( - self._application_generate_entity.app_config.tenant_id, - event.provider_id, - ) + try: + if event.node_type == NodeType.TOOL: + response.data.extras["icon"] = ToolManager.get_tool_icon( + tenant_id=self._application_generate_entity.app_config.tenant_id, + provider_type=ToolProviderType(event.provider_type), + provider_id=event.provider_id, + ) + elif event.node_type == NodeType.DATASOURCE: + manager = PluginDatasourceManager() + provider_entity = manager.fetch_datasource_provider( + self._application_generate_entity.app_config.tenant_id, + event.provider_id, + ) + response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url( + self._application_generate_entity.app_config.tenant_id + ) + elif event.node_type == NodeType.TRIGGER_PLUGIN: + response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon( + self._application_generate_entity.app_config.tenant_id, + event.provider_id, + ) + except Exception: + # metadata fetch may fail, for example, the plugin daemon is down or plugin is uninstalled. + logger.warning("failed to fetch icon for %s", event.provider_id) return response diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index e2be4146e1..ddfb5725b4 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import ( CompletionAppGenerateEntity, ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.file import File from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.moderation.base import ModerationError @@ -102,6 +103,7 @@ class CompletionAppRunner(AppRunner): # get context from datasets context = None + context_files: list[File] = [] if app_config.dataset and app_config.dataset.dataset_ids: hit_callback = DatasetIndexToolCallbackHandler( queue_manager, @@ -116,7 +118,7 @@ class CompletionAppRunner(AppRunner): query = inputs.get(dataset_config.retrieve_config.query_variable, "") dataset_retrieval = DatasetRetrieval(application_generate_entity) - context = dataset_retrieval.retrieve( + context, retrieved_files = dataset_retrieval.retrieve( app_id=app_record.id, user_id=application_generate_entity.user_id, tenant_id=app_record.tenant_id, @@ -130,7 +132,11 @@ class CompletionAppRunner(AppRunner): hit_callback=hit_callback, message_id=message.id, inputs=inputs, + vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( + "enabled", False + ), ) + context_files = retrieved_files or [] # reorganize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -144,6 +150,7 @@ class CompletionAppRunner(AppRunner): query=query, context=context, image_detail_config=image_detail_config, + context_files=context_files, ) # check hosting moderation diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 53e67fd578..57617d8863 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -156,79 +156,86 @@ class MessageBasedAppGenerator(BaseAppGenerator): query = application_generate_entity.query or "New conversation" conversation_name = (query[:20] + "…") if len(query) > 20 else query - if not conversation: - conversation = Conversation( + try: + if not conversation: + conversation = Conversation( + app_id=app_config.app_id, + app_model_config_id=app_model_config_id, + model_provider=model_provider, + model_id=model_id, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + mode=app_config.app_mode.value, + name=conversation_name, + inputs=application_generate_entity.inputs, + introduction=introduction, + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=application_generate_entity.invoke_from.value, + from_source=from_source, + from_end_user_id=end_user_id, + from_account_id=account_id, + ) + + db.session.add(conversation) + db.session.flush() + db.session.refresh(conversation) + else: + conversation.updated_at = naive_utc_now() + + message = Message( app_id=app_config.app_id, - app_model_config_id=app_model_config_id, model_provider=model_provider, model_id=model_id, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - mode=app_config.app_mode.value, - name=conversation_name, + conversation_id=conversation.id, inputs=application_generate_entity.inputs, - introduction=introduction, - system_instruction="", - system_instruction_tokens=0, - status="normal", + query=application_generate_entity.query, + message="", + message_tokens=0, + message_unit_price=0, + message_price_unit=0, + answer="", + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + parent_message_id=getattr(application_generate_entity, "parent_message_id", None), + provider_response_latency=0, + total_price=0, + currency="USD", invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, from_account_id=account_id, + app_mode=app_config.app_mode, ) - db.session.add(conversation) + db.session.add(message) + db.session.flush() + db.session.refresh(message) + + message_files = [] + for file in application_generate_entity.files: + message_file = MessageFile( + message_id=message.id, + type=file.type, + transfer_method=file.transfer_method, + belongs_to="user", + url=file.remote_url, + upload_file_id=file.related_id, + created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), + created_by=account_id or end_user_id or "", + ) + message_files.append(message_file) + + if message_files: + db.session.add_all(message_files) + db.session.commit() - db.session.refresh(conversation) - else: - conversation.updated_at = naive_utc_now() - db.session.commit() - - message = Message( - app_id=app_config.app_id, - model_provider=model_provider, - model_id=model_id, - override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, - conversation_id=conversation.id, - inputs=application_generate_entity.inputs, - query=application_generate_entity.query, - message="", - message_tokens=0, - message_unit_price=0, - message_price_unit=0, - answer="", - answer_tokens=0, - answer_unit_price=0, - answer_price_unit=0, - parent_message_id=getattr(application_generate_entity, "parent_message_id", None), - provider_response_latency=0, - total_price=0, - currency="USD", - invoke_from=application_generate_entity.invoke_from.value, - from_source=from_source, - from_end_user_id=end_user_id, - from_account_id=account_id, - app_mode=app_config.app_mode, - ) - - db.session.add(message) - db.session.commit() - db.session.refresh(message) - - for file in application_generate_entity.files: - message_file = MessageFile( - message_id=message.id, - type=file.type, - transfer_method=file.transfer_method, - belongs_to="user", - url=file.remote_url, - upload_file_id=file.related_id, - created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), - created_by=account_id or end_user_id or "", - ) - db.session.add(message_file) - db.session.commit() - - return conversation, message + return conversation, message + except Exception: + db.session.rollback() + raise def _get_conversation_introduction(self, application_generate_entity: AppGenerateEntity) -> str: """ diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index d8460df390..894e6f397a 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -18,6 +18,7 @@ from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_redis import redis_client +from extensions.otel import WorkflowAppRunnerHandler, trace_span from libs.datetime_utils import naive_utc_now from models.enums import UserFrom from models.workflow import Workflow @@ -56,6 +57,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + @trace_span(WorkflowAppRunnerHandler) def run(self): """ Run application diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7692128985..79a5e657b3 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -40,9 +40,6 @@ class EasyUITaskState(TaskState): """ llm_result: LLMResult - first_token_time: float | None = None - last_token_time: float | None = None - is_streaming_response: bool = False class WorkflowTaskState(TaskState): diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index c49db9aad1..5c169f4db1 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -332,12 +332,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if not self._task_state.llm_result.prompt_messages: self._task_state.llm_result.prompt_messages = chunk.prompt_messages - # Track streaming response times - if self._task_state.first_token_time is None: - self._task_state.first_token_time = time.perf_counter() - self._task_state.is_streaming_response = True - self._task_state.last_token_time = time.perf_counter() - # handle output moderation chunk should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text)) if should_direct_answer: @@ -366,7 +360,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if publisher: publisher.publish(None) if self._conversation_name_generate_thread: - self._conversation_name_generate_thread.join() + logger.debug("Conversation name generation running as daemon thread") def _save_message(self, *, session: Session, trace_manager: TraceQueueManager | None = None): """ @@ -404,18 +398,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): message.total_price = usage.total_price message.currency = usage.currency self._task_state.llm_result.usage.latency = message.provider_response_latency - - # Add streaming metrics to usage if available - if self._task_state.is_streaming_response and self._task_state.first_token_time: - start_time = self.start_at - first_token_time = self._task_state.first_token_time - last_token_time = self._task_state.last_token_time or first_token_time - usage.time_to_first_token = round(first_token_time - start_time, 3) - usage.time_to_generate = round(last_token_time - first_token_time, 3) - - # Update metadata with the complete usage info - self._task_state.metadata.usage = usage - message.message_metadata = self._task_state.metadata.model_dump_json() if trace_manager: diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index e7daeb4a32..2e6f92efa5 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -1,4 +1,6 @@ +import hashlib import logging +import time from threading import Thread from typing import Union @@ -31,6 +33,7 @@ from core.app.entities.task_entities import ( from core.llm_generator.llm_generator import LLMGenerator from core.tools.signature import sign_tool_file from extensions.ext_database import db +from extensions.ext_redis import redis_client from models.model import AppMode, Conversation, MessageAnnotation, MessageFile from services.annotation_service import AppAnnotationService @@ -68,6 +71,8 @@ class MessageCycleManager: if auto_generate_conversation_name and is_first_message: # start generate thread + # time.sleep not block other logic + time.sleep(1) thread = Thread( target=self._generate_conversation_name_worker, kwargs={ @@ -76,7 +81,7 @@ class MessageCycleManager: "query": query, }, ) - + thread.daemon = True thread.start() return thread @@ -98,15 +103,23 @@ class MessageCycleManager: return # generate conversation name - try: - name = LLMGenerator.generate_conversation_name( - app_model.tenant_id, query, conversation_id, conversation.app_id - ) - conversation.name = name - except Exception: - if dify_config.DEBUG: - logger.exception("generate conversation name failed, conversation_id: %s", conversation_id) + query_hash = hashlib.md5(query.encode()).hexdigest()[:16] + cache_key = f"conv_name:{conversation_id}:{query_hash}" + cached_name = redis_client.get(cache_key) + if cached_name: + name = cached_name.decode("utf-8") + else: + try: + name = LLMGenerator.generate_conversation_name( + app_model.tenant_id, query, conversation_id, conversation.app_id + ) + redis_client.setex(cache_key, 3600, name) + except Exception: + if dify_config.DEBUG: + logger.exception("generate conversation name failed, conversation_id: %s", conversation_id) + name = query[:47] + "..." if len(query) > 50 else query + conversation.name = name db.session.commit() db.session.close() diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 14d5f38dcd..d0279349ca 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -7,7 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment @@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler: document_id, ) continue - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunk_stmt = select(ChildChunk).where( ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.dataset_id == dataset_document.dataset_id, diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index b9ca7414dc..bed3a35400 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field class PreviewDetail(BaseModel): @@ -20,7 +20,7 @@ class IndexingEstimate(BaseModel): class PipelineDataset(BaseModel): id: str name: str - description: str + description: str | None = Field(default="", description="knowledge dataset description") chunk_structure: str diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index 7484cea04a..7fdf5e4be6 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -213,12 +213,23 @@ class MCPProviderEntity(BaseModel): return None def retrieve_tokens(self) -> OAuthTokens | None: - """OAuth tokens if available""" + """Retrieve OAuth tokens if authentication is complete. + + Returns: + OAuthTokens if the provider has been authenticated, None otherwise. + """ if not self.credentials: return None credentials = self.decrypt_credentials() + access_token = credentials.get("access_token", "") + # Return None if access_token is empty to avoid generating invalid "Authorization: Bearer " header. + # Note: We don't check for whitespace-only strings here because: + # 1. OAuth servers don't return whitespace-only access tokens in practice + # 2. Even if they did, the server would return 401, triggering the OAuth flow correctly + if not access_token: + return None return OAuthTokens( - access_token=credentials.get("access_token", ""), + access_token=access_token, token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE), expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN), refresh_token=credentials.get("refresh_token", ""), diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 663a8164c6..12431976f0 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -29,6 +29,7 @@ class SimpleModelProviderEntity(BaseModel): provider: str label: I18nObject icon_small: I18nObject | None = None + icon_small_dark: I18nObject | None = None icon_large: I18nObject | None = None supported_model_types: list[ModelType] @@ -42,6 +43,7 @@ class SimpleModelProviderEntity(BaseModel): provider=provider_entity.provider, label=provider_entity.label, icon_small=provider_entity.icon_small, + icon_small_dark=provider_entity.icon_small_dark, icon_large=provider_entity.icon_large, supported_model_types=provider_entity.supported_model_types, ) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 56c133e598..e8d41b9387 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -253,7 +253,7 @@ class ProviderConfiguration(BaseModel): try: credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key]) except Exception: - pass + logger.exception("Failed to decrypt credential secret variable %s", key) return self.obfuscated_credentials( credentials=credentials, @@ -765,7 +765,7 @@ class ProviderConfiguration(BaseModel): try: credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key]) except Exception: - pass + logger.exception("Failed to decrypt model credential secret variable %s", key) current_credential_id = credential_record.id current_credential_name = credential_record.credential_name diff --git a/api/core/helper/marketplace.py b/api/core/helper/marketplace.py index b2286d39ed..25dc4ba9ed 100644 --- a/api/core/helper/marketplace.py +++ b/api/core/helper/marketplace.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Sequence import httpx @@ -8,6 +9,7 @@ from core.helper.download import download_with_size_limit from core.plugin.entities.marketplace import MarketplacePluginDeclaration marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL)) +logger = logging.getLogger(__name__) def get_plugin_pkg_url(plugin_unique_identifier: str) -> str: @@ -55,7 +57,9 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error( try: result.append(MarketplacePluginDeclaration.model_validate(plugin)) except Exception: - pass + logger.exception( + "Failed to deserialize marketplace plugin manifest for %s", plugin.get("plugin_id", "unknown") + ) return result diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py new file mode 100644 index 0000000000..eef5937407 --- /dev/null +++ b/api/core/helper/tool_provider_cache.py @@ -0,0 +1,56 @@ +import json +import logging +from typing import Any + +from core.tools.entities.api_entities import ToolProviderTypeApiLiteral +from extensions.ext_redis import redis_client, redis_fallback + +logger = logging.getLogger(__name__) + + +class ToolProviderListCache: + """Cache for tool provider lists""" + + CACHE_TTL = 300 # 5 minutes + + @staticmethod + def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str: + """Generate cache key for tool providers list""" + type_filter = typ or "all" + return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}" + + @staticmethod + @redis_fallback(default_return=None) + def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None: + """Get cached tool providers""" + cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) + cached_data = redis_client.get(cache_key) + if cached_data: + try: + return json.loads(cached_data.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError): + logger.warning("Failed to decode cached tool providers data") + return None + return None + + @staticmethod + @redis_fallback() + def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]): + """Cache tool providers""" + cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) + redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers)) + + @staticmethod + @redis_fallback() + def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None): + """Invalidate cache for tool providers""" + if typ: + # Invalidate specific type cache + cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) + redis_client.delete(cache_key) + else: + # Invalidate all caches for this tenant + pattern = f"tool_providers:tenant_id:{tenant_id}:*" + keys = list(redis_client.scan_iter(pattern)) + if keys: + redis_client.delete(*keys) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 36b38b7b45..59de4f403d 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -7,7 +7,7 @@ import time import uuid from typing import Any -from flask import current_app +from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm.exc import ObjectDeletedError @@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document @@ -36,6 +36,7 @@ from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper from libs.datetime_utils import naive_utc_now +from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile @@ -89,8 +90,17 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform + current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + if not current_user: + raise ValueError("no current user found") + current_user.set_tenant_id(dataset.tenant_id) documents = self._transform( - index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() + index_processor, + dataset, + text_docs, + requeried_document.doc_language, + processing_rule.to_dict(), + current_user=current_user, ) # save segment self._load_segments(dataset, requeried_document, documents) @@ -136,7 +146,7 @@ class IndexingRunner: for document_segment in document_segments: db.session.delete(document_segment) - if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: # delete child chunks db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() @@ -152,8 +162,17 @@ class IndexingRunner: text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict()) # transform + current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first() + if not current_user: + raise ValueError("no current user found") + current_user.set_tenant_id(dataset.tenant_id) documents = self._transform( - index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict() + index_processor, + dataset, + text_docs, + requeried_document.doc_language, + processing_rule.to_dict(), + current_user=current_user, ) # save segment self._load_segments(dataset, requeried_document, documents) @@ -209,7 +228,7 @@ class IndexingRunner: "dataset_id": document_segment.dataset_id, }, ) - if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = document_segment.get_child_chunks() if child_chunks: child_documents = [] @@ -302,6 +321,7 @@ class IndexingRunner: text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) documents = index_processor.transform( text_docs, + current_user=None, embedding_model_instance=embedding_model_instance, process_rule=processing_rule.to_dict(), tenant_id=tenant_id, @@ -551,7 +571,10 @@ class IndexingRunner: indexing_start_at = time.perf_counter() tokens = 0 create_keyword_thread = None - if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": + if ( + dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX + and dataset.indexing_technique == "economy" + ): # create keyword index create_keyword_thread = threading.Thread( target=self._process_keyword_index, @@ -590,7 +613,7 @@ class IndexingRunner: for future in futures: tokens += future.result() if ( - dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX + dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy" and create_keyword_thread is not None ): @@ -635,7 +658,13 @@ class IndexingRunner: db.session.commit() def _process_chunk( - self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance + self, + flask_app: Flask, + index_processor: BaseIndexProcessor, + chunk_documents: list[Document], + dataset: Dataset, + dataset_document: DatasetDocument, + embedding_model_instance: ModelInstance | None, ): with flask_app.app_context(): # check document is paused @@ -646,8 +675,15 @@ class IndexingRunner: page_content_list = [document.page_content for document in chunk_documents] tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list)) + multimodal_documents = [] + for document in chunk_documents: + if document.attachments and dataset.is_multimodal: + multimodal_documents.extend(document.attachments) + # load index - index_processor.load(dataset, chunk_documents, with_keywords=False) + index_processor.load( + dataset, chunk_documents, multimodal_documents=multimodal_documents, with_keywords=False + ) document_ids = [document.metadata["doc_id"] for document in chunk_documents] db.session.query(DocumentSegment).where( @@ -710,6 +746,7 @@ class IndexingRunner: text_docs: list[Document], doc_language: str, process_rule: dict, + current_user: Account | None = None, ) -> list[Document]: # get embedding model instance embedding_model_instance = None @@ -729,6 +766,7 @@ class IndexingRunner: documents = index_processor.transform( text_docs, + current_user, embedding_model_instance=embedding_model_instance, process_rule=process_rule, tenant_id=dataset.tenant_id, @@ -737,14 +775,16 @@ class IndexingRunner: return documents - def _load_segments(self, dataset, dataset_document, documents): + def _load_segments(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]): # save node to document segment doc_store = DatasetDocumentStore( dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments - doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) + doc_store.add_documents( + docs=documents, save_child=dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX + ) # update document status to indexing cur_time = naive_utc_now() diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index bd893b17f1..4a577e6c38 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -15,6 +15,8 @@ from core.llm_generator.prompts import ( LLM_MODIFY_CODE_SYSTEM, LLM_MODIFY_PROMPT_SYSTEM, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, + SUGGESTED_QUESTIONS_MAX_TOKENS, + SUGGESTED_QUESTIONS_TEMPERATURE, SYSTEM_STRUCTURED_OUTPUT_GENERATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) @@ -124,7 +126,10 @@ class LLMGenerator: try: response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompt_messages), - model_parameters={"max_tokens": 256, "temperature": 0}, + model_parameters={ + "max_tokens": SUGGESTED_QUESTIONS_MAX_TOKENS, + "temperature": SUGGESTED_QUESTIONS_TEMPERATURE, + }, stream=False, ) @@ -549,11 +554,16 @@ class LLMGenerator: prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False ) - generated_raw = cast(str, response.message.content) + generated_raw = response.message.get_text_content() first_brace = generated_raw.find("{") last_brace = generated_raw.rfind("}") - return {**json.loads(generated_raw[first_brace : last_brace + 1])} - + if first_brace == -1 or last_brace == -1 or last_brace < first_brace: + raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}") + json_str = generated_raw[first_brace : last_brace + 1] + data = json_repair.loads(json_str) + if not isinstance(data, dict): + raise TypeError(f"Expected a JSON object, but got {type(data).__name__}") + return data except InvokeError as e: error = str(e) return {"error": f"Failed to generate code. Error: {error}"} diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index 9268347526..ec2b7f2d44 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -1,4 +1,6 @@ # Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh +import os + CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”. 1. Detect Input Language @@ -94,7 +96,8 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = ( ) -SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( +# Default prompt for suggested questions (can be overridden by environment variable) +_DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = ( "Please help me predict the three most likely questions that human would ask, " "and keep each question under 20 characters.\n" "MAKE SURE your output is the SAME language as the Assistant's latest response. " @@ -102,6 +105,15 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( '["question1","question2","question3"]\n' ) +# Environment variable override for suggested questions prompt +SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = os.getenv( + "SUGGESTED_QUESTIONS_PROMPT", _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT +) + +# Configurable LLM parameters for suggested questions (can be overridden by environment variables) +SUGGESTED_QUESTIONS_MAX_TOKENS = int(os.getenv("SUGGESTED_QUESTIONS_MAX_TOKENS", "256")) +SUGGESTED_QUESTIONS_TEMPERATURE = float(os.getenv("SUGGESTED_QUESTIONS_TEMPERATURE", "0")) + GENERATOR_QA_PROMPT = ( " The user will send a long text. Generate a Question and Answer pairs only using the knowledge" " in the long text. Please think step by step." diff --git a/api/core/model_manager.py b/api/core/model_manager.py index a63e94d59c..5a28bbcc3a 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -10,9 +10,9 @@ from core.errors.error import ProviderTokenNotInitError from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.moderation_model import ModerationModel @@ -200,7 +200,7 @@ class ModelInstance: def invoke_text_embedding( self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke large language model @@ -212,7 +212,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return cast( - TextEmbeddingResult, + EmbeddingResult, self._round_robin_invoke( function=self.model_type_instance.invoke, model=self.model, @@ -223,6 +223,34 @@ class ModelInstance: ), ) + def invoke_multimodal_embedding( + self, + multimodel_documents: list[dict], + user: str | None = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> EmbeddingResult: + """ + Invoke large language model + + :param multimodel_documents: multimodel documents to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + if not isinstance(self.model_type_instance, TextEmbeddingModel): + raise Exception("Model type instance is not TextEmbeddingModel") + return cast( + EmbeddingResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + multimodel_documents=multimodel_documents, + user=user, + input_type=input_type, + ), + ) + def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]: """ Get number of tokens for text embedding @@ -276,6 +304,40 @@ class ModelInstance: ), ) + def invoke_multimodal_rerank( + self, + query: dict, + docs: list[dict], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if not isinstance(self.model_type_instance, RerankModel): + raise Exception("Model type instance is not RerankModel") + return cast( + RerankResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke_multimodal_rerank, + model=self.model, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + user=user, + ), + ) + def invoke_moderation(self, text: str, user: str | None = None) -> bool: """ Invoke moderation model @@ -461,6 +523,32 @@ class ModelManager: model=default_model_entity.model, ) + def check_model_support_vision(self, tenant_id: str, provider: str, model: str, model_type: ModelType) -> bool: + """ + Check if model supports vision + :param tenant_id: tenant id + :param provider: provider name + :param model: model name + :return: True if model supports vision, False otherwise + """ + model_instance = self.get_model_instance(tenant_id, provider, model_type, model) + model_type_instance = model_instance.model_type_instance + match model_type: + case ModelType.LLM: + model_type_instance = cast(LargeLanguageModel, model_type_instance) + case ModelType.TEXT_EMBEDDING: + model_type_instance = cast(TextEmbeddingModel, model_type_instance) + case ModelType.RERANK: + model_type_instance = cast(RerankModel, model_type_instance) + case _: + raise ValueError(f"Model type {model_type} is not supported") + model_schema = model_type_instance.get_model_schema(model, model_instance.credentials) + if not model_schema: + return False + if model_schema.features and ModelFeature.VISION in model_schema.features: + return True + return False + class LBModelManager: def __init__( diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index 0508116962..648b209ef1 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -99,6 +99,7 @@ class SimpleProviderEntity(BaseModel): provider: str label: I18nObject icon_small: I18nObject | None = None + icon_small_dark: I18nObject | None = None icon_large: I18nObject | None = None supported_model_types: Sequence[ModelType] models: list[AIModelEntity] = [] @@ -124,7 +125,6 @@ class ProviderEntity(BaseModel): icon_small: I18nObject | None = None icon_large: I18nObject | None = None icon_small_dark: I18nObject | None = None - icon_large_dark: I18nObject | None = None background: str | None = None help: ProviderHelpEntity | None = None supported_model_types: Sequence[ModelType] diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 846b89d658..854c448250 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -19,7 +19,7 @@ class EmbeddingUsage(ModelUsage): latency: float -class TextEmbeddingResult(BaseModel): +class EmbeddingResult(BaseModel): """ Model class for text embedding result. """ @@ -27,3 +27,13 @@ class TextEmbeddingResult(BaseModel): model: str embeddings: list[list[float]] usage: EmbeddingUsage + + +class FileEmbeddingResult(BaseModel): + """ + Model class for file embedding result. + """ + + model: str + embeddings: list[list[float]] + usage: EmbeddingUsage diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 36067118b0..0a576b832a 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -50,3 +50,43 @@ class RerankModel(AIModel): ) except Exception as e: raise self._transform_invoke_error(e) + + def invoke_multimodal_rerank( + self, + model: str, + credentials: dict, + query: dict, + docs: list[dict], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> RerankResult: + """ + Invoke multimodal rerank model + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + try: + from core.plugin.impl.model import PluginModelClient + + plugin_model_manager = PluginModelClient() + return plugin_model_manager.invoke_multimodal_rerank( + tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + except Exception as e: + raise self._transform_invoke_error(e) diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index bd68ffe903..4c902e2c11 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -2,7 +2,7 @@ from pydantic import ConfigDict from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -20,16 +20,18 @@ class TextEmbeddingModel(AIModel): self, model: str, credentials: dict, - texts: list[str], + texts: list[str] | None = None, + multimodel_documents: list[dict] | None = None, user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke text embedding model :param model: model name :param credentials: model credentials :param texts: texts to embed + :param files: files to embed :param user: unique user id :param input_type: input type :return: embeddings result @@ -38,16 +40,29 @@ class TextEmbeddingModel(AIModel): try: plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_text_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - texts=texts, - input_type=input_type, - ) + if texts: + return plugin_model_manager.invoke_text_embedding( + tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + texts=texts, + input_type=input_type, + ) + if multimodel_documents: + return plugin_model_manager.invoke_multimodal_embedding( + tenant_id=self.tenant_id, + user_id=user or "unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model=model, + credentials=credentials, + documents=multimodel_documents, + input_type=input_type, + ) + raise ValueError("No texts or files provided") except Exception as e: raise self._transform_invoke_error(e) diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index e1afc41bee..b8704ef4ed 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -300,6 +300,14 @@ class ModelProviderFactory: file_name = provider_schema.icon_small.zh_Hans else: file_name = provider_schema.icon_small.en_US + elif icon_type.lower() == "icon_small_dark": + if not provider_schema.icon_small_dark: + raise ValueError(f"Provider {provider} does not have small dark icon.") + + if lang.lower() == "zh_hans": + file_name = provider_schema.icon_small_dark.zh_Hans + else: + file_name = provider_schema.icon_small_dark.en_US else: if not provider_schema.icon_large: raise ValueError(f"Provider {provider} does not have large icon.") diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index a7d8576d8d..d6bd4d2015 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -296,7 +296,7 @@ class AliyunDataTrace(BaseTraceInstance): node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata) return node_span except Exception as e: - logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True) + logger.warning("Error occurred in build_workflow_node_span: %s", e, exc_info=True) return None def build_workflow_task_span( diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index 5aa9fb6689..d3324f8f82 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -21,6 +21,7 @@ from opentelemetry.trace import Link, SpanContext, TraceFlags from configs import dify_config from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData +from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE INVALID_SPAN_ID: Final[int] = 0x0000000000000000 INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000 @@ -48,6 +49,7 @@ class TraceClient: ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}", ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}", ResourceAttributes.HOST_NAME: socket.gethostname(), + ACS_ARMS_SERVICE_FEATURE: "genai_app", } ) self.span_builder = SpanBuilder(self.resource) @@ -75,10 +77,10 @@ class TraceClient: if response.status_code == 405: return True else: - logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code) + logger.warning("AliyunTrace API check failed: Unexpected status code: %s", response.status_code) return False except httpx.RequestError as e: - logger.debug("AliyunTrace API check failed: %s", str(e)) + logger.warning("AliyunTrace API check failed: %s", str(e)) raise ValueError(f"AliyunTrace API check failed: {str(e)}") def get_project_url(self) -> str: @@ -116,7 +118,7 @@ class TraceClient: try: self.exporter.export(spans_to_export) except Exception as e: - logger.debug("Error exporting spans: %s", e) + logger.warning("Error exporting spans: %s", e) def shutdown(self) -> None: with self.condition: diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/core/ops/aliyun_trace/entities/semconv.py index c823fcab8a..aff893816c 100644 --- a/api/core/ops/aliyun_trace/entities/semconv.py +++ b/api/core/ops/aliyun_trace/entities/semconv.py @@ -1,6 +1,8 @@ from enum import StrEnum from typing import Final +ACS_ARMS_SERVICE_FEATURE: Final[str] = "acs.arms.service.feature" + # Public attributes GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id" GEN_AI_USER_ID: Final[str] = "gen_ai.user.id" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index ce2b0239cd..f45f15a6da 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -377,20 +377,20 @@ class OpsTraceManager: return app_model_config @classmethod - def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str): + def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str | None): """ Update app tracing config :param app_id: app id :param enabled: enabled - :param tracing_provider: tracing provider + :param tracing_provider: tracing provider (None when disabling) :return: """ # auth check - try: - if enabled or tracing_provider is not None: + if tracing_provider is not None: + try: provider_config_map[tracing_provider] - except KeyError: - raise ValueError(f"Invalid tracing provider: {tracing_provider}") + except KeyError: + raise ValueError(f"Invalid tracing provider: {tracing_provider}") app_config: App | None = db.session.query(App).where(App.id == app_id).first() if not app_config: diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py index db92e9b8bd..26e8779e3e 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/core/ops/tencent_trace/span_builder.py @@ -222,59 +222,6 @@ class TencentSpanBuilder: links=links, ) - @staticmethod - def build_message_llm_span( - trace_info: MessageTraceInfo, trace_id: int, parent_span_id: int, user_id: str - ) -> SpanData: - """Build LLM span for message traces with detailed LLM attributes.""" - status = Status(StatusCode.OK) - if trace_info.error: - status = Status(StatusCode.ERROR, trace_info.error) - - # Extract model information from `metadata`` or `message_data` - trace_metadata = trace_info.metadata or {} - message_data = trace_info.message_data or {} - - model_provider = trace_metadata.get("ls_provider") or ( - message_data.get("model_provider", "") if isinstance(message_data, dict) else "" - ) - model_name = trace_metadata.get("ls_model_name") or ( - message_data.get("model_id", "") if isinstance(message_data, dict) else "" - ) - - inputs_str = str(trace_info.inputs or "") - outputs_str = str(trace_info.outputs or "") - - attributes = { - GEN_AI_SESSION_ID: trace_metadata.get("conversation_id", ""), - GEN_AI_USER_ID: str(user_id), - GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value, - GEN_AI_FRAMEWORK: "dify", - GEN_AI_MODEL_NAME: str(model_name), - GEN_AI_PROVIDER: str(model_provider), - GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens or 0), - GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens or 0), - GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens or 0), - GEN_AI_PROMPT: inputs_str, - GEN_AI_COMPLETION: outputs_str, - INPUT_VALUE: inputs_str, - OUTPUT_VALUE: outputs_str, - } - - if trace_info.is_streaming_request: - attributes[GEN_AI_IS_STREAMING_REQUEST] = "true" - - return SpanData( - trace_id=trace_id, - parent_span_id=parent_span_id, - span_id=TencentTraceUtils.convert_to_span_id(trace_info.message_id, "llm"), - name="GENERATION", - start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time), - end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time), - attributes=attributes, - status=status, - ) - @staticmethod def build_tool_span(trace_info: ToolTraceInfo, trace_id: int, parent_span_id: int) -> SpanData: """Build tool span.""" diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 3d176da97a..93ec186863 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -107,12 +107,8 @@ class TencentDataTrace(BaseTraceInstance): links.append(TencentTraceUtils.create_link(trace_info.trace_id)) message_span = TencentSpanBuilder.build_message_span(trace_info, trace_id, str(user_id), links) - self.trace_client.add_span(message_span) - # Add LLM child span with detailed attributes - parent_span_id = TencentTraceUtils.convert_to_span_id(trace_info.message_id, "message") - llm_span = TencentSpanBuilder.build_message_llm_span(trace_info, trace_id, parent_span_id, str(user_id)) - self.trace_client.add_span(llm_span) + self.trace_client.add_span(message_span) self._record_message_llm_metrics(trace_info) @@ -521,4 +517,4 @@ class TencentDataTrace(BaseTraceInstance): if hasattr(self, "trace_client"): self.trace_client.shutdown() except Exception: - pass + logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup") diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 5dfc3c212e..5d70980967 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -6,7 +6,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -243,14 +243,14 @@ class PluginModelClient(BasePluginClient): credentials: dict, texts: list[str], input_type: str, - ) -> TextEmbeddingResult: + ) -> EmbeddingResult: """ Invoke text embedding """ response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", - type_=TextEmbeddingResult, + type_=EmbeddingResult, data=jsonable_encoder( { "user_id": user_id, @@ -275,6 +275,48 @@ class PluginModelClient(BasePluginClient): raise ValueError("Failed to invoke text embedding") + def invoke_multimodal_embedding( + self, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + model: str, + credentials: dict, + documents: list[dict], + input_type: str, + ) -> EmbeddingResult: + """ + Invoke file embedding + """ + response = self._request_with_plugin_daemon_response_stream( + method="POST", + path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke", + type_=EmbeddingResult, + data=jsonable_encoder( + { + "user_id": user_id, + "data": { + "provider": provider, + "model_type": "text-embedding", + "model": model, + "credentials": credentials, + "documents": documents, + "input_type": input_type, + }, + } + ), + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + + for resp in response: + return resp + + raise ValueError("Failed to invoke file embedding") + def get_text_embedding_num_tokens( self, tenant_id: str, @@ -361,6 +403,51 @@ class PluginModelClient(BasePluginClient): raise ValueError("Failed to invoke rerank") + def invoke_multimodal_rerank( + self, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + model: str, + credentials: dict, + query: dict, + docs: list[dict], + score_threshold: float | None = None, + top_n: int | None = None, + ) -> RerankResult: + """ + Invoke multimodal rerank + """ + response = self._request_with_plugin_daemon_response_stream( + method="POST", + path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke", + type_=RerankResult, + data=jsonable_encoder( + { + "user_id": user_id, + "data": { + "provider": provider, + "model_type": "rerank", + "model": model, + "credentials": credentials, + "query": query, + "docs": docs, + "score_threshold": score_threshold, + "top_n": top_n, + }, + } + ), + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + + raise ValueError("Failed to invoke multimodal rerank") + def invoke_tts( self, tenant_id: str, diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d1d518a55d..f072092ea7 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -49,6 +49,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: inputs = {key: str(value) for key, value in inputs.items()} @@ -64,6 +65,7 @@ class SimplePromptTransform(PromptTransform): memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( @@ -76,6 +78,7 @@ class SimplePromptTransform(PromptTransform): memory=memory, model_config=model_config, image_detail_config=image_detail_config, + context_files=context_files, ) return prompt_messages, stops @@ -187,6 +190,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -216,9 +220,9 @@ class SimplePromptTransform(PromptTransform): ) if query: - prompt_messages.append(self._get_last_user_message(query, files, image_detail_config)) + prompt_messages.append(self._get_last_user_message(query, files, image_detail_config, context_files)) else: - prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config)) + prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config, context_files)) return prompt_messages, None @@ -233,6 +237,7 @@ class SimplePromptTransform(PromptTransform): memory: TokenBufferMemory | None, model_config: ModelConfigWithCredentialsEntity, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> tuple[list[PromptMessage], list[str] | None]: # get prompt prompt, prompt_rules = self._get_prompt_str_and_rules( @@ -275,20 +280,27 @@ class SimplePromptTransform(PromptTransform): if stops is not None and len(stops) == 0: stops = None - return [self._get_last_user_message(prompt, files, image_detail_config)], stops + return [self._get_last_user_message(prompt, files, image_detail_config, context_files)], stops def _get_last_user_message( self, prompt: str, files: Sequence["File"], image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + context_files: list["File"] | None = None, ) -> UserPromptMessage: + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] if files: - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + if context_files: + for file in context_files: + prompt_message_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) + ) + if prompt_message_contents: prompt_message_contents.append(TextPromptMessageContent(data=prompt)) prompt_message = UserPromptMessage(content=prompt_message_contents) diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index cc946a72c3..bfa8781e9f 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -2,6 +2,7 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner @@ -30,9 +31,10 @@ class DataPostProcessor: score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: if self.rerank_runner: - documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) + documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type) if self.reorder_runner: documents = self.reorder_runner.run(documents) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 2290de19bc..a139fba4d0 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,23 +1,30 @@ import concurrent.futures from concurrent.futures import ThreadPoolExecutor +from typing import Any from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm import Session, load_only from configs import dify_config +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.embedding.retrieval import RetrievalSegments from core.rag.entities.metadata_entities import MetadataCondition -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.signature import sign_upload_file from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument +from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { @@ -37,14 +44,15 @@ class RetrievalService: retrieval_method: RetrievalMethod, dataset_id: str, query: str, - top_k: int, + top_k: int = 4, score_threshold: float | None = 0.0, reranking_model: dict | None = None, reranking_mode: str = "reranking_model", weights: dict | None = None, document_ids_filter: list[str] | None = None, + attachment_ids: list | None = None, ): - if not query: + if not query and not attachment_ids: return [] dataset = cls._get_dataset(dataset_id) if not dataset: @@ -56,69 +64,52 @@ class RetrievalService: # Optimize multithreading with thread pools with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore futures = [] - if retrieval_method == RetrievalMethod.KEYWORD_SEARCH: + retrieval_service = RetrievalService() + if query: futures.append( executor.submit( - cls.keyword_search, + retrieval_service._retrieve, flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, - query=query, - top_k=top_k, - all_documents=all_documents, - exceptions=exceptions, - document_ids_filter=document_ids_filter, - ) - ) - if RetrievalMethod.is_support_semantic_search(retrieval_method): - futures.append( - executor.submit( - cls.embedding_search, - flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, + retrieval_method=retrieval_method, + dataset=dataset, query=query, top_k=top_k, score_threshold=score_threshold, reranking_model=reranking_model, - all_documents=all_documents, - retrieval_method=retrieval_method, - exceptions=exceptions, + reranking_mode=reranking_mode, + weights=weights, document_ids_filter=document_ids_filter, + attachment_id=None, + all_documents=all_documents, + exceptions=exceptions, ) ) - if RetrievalMethod.is_support_fulltext_search(retrieval_method): - futures.append( - executor.submit( - cls.full_text_index_search, - flask_app=current_app._get_current_object(), # type: ignore - dataset_id=dataset_id, - query=query, - top_k=top_k, - score_threshold=score_threshold, - reranking_model=reranking_model, - all_documents=all_documents, - retrieval_method=retrieval_method, - exceptions=exceptions, - document_ids_filter=document_ids_filter, + if attachment_ids: + for attachment_id in attachment_ids: + futures.append( + executor.submit( + retrieval_service._retrieve, + flask_app=current_app._get_current_object(), # type: ignore + retrieval_method=retrieval_method, + dataset=dataset, + query=None, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + reranking_mode=reranking_mode, + weights=weights, + document_ids_filter=document_ids_filter, + attachment_id=attachment_id, + all_documents=all_documents, + exceptions=exceptions, + ) ) - ) - concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED) + + concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED) if exceptions: raise ValueError(";\n".join(exceptions)) - # Deduplicate documents for hybrid search to avoid duplicate chunks - if retrieval_method == RetrievalMethod.HYBRID_SEARCH: - all_documents = cls._deduplicate_documents(all_documents) - data_post_processor = DataPostProcessor( - str(dataset.tenant_id), reranking_mode, reranking_model, weights, False - ) - all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k, - ) - return all_documents @classmethod @@ -223,6 +214,7 @@ class RetrievalService: retrieval_method: RetrievalMethod, exceptions: list, document_ids_filter: list[str] | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ): with flask_app.app_context(): try: @@ -231,14 +223,30 @@ class RetrievalService: raise ValueError("dataset not found") vector = Vector(dataset=dataset) - documents = vector.search_by_vector( - query, - search_type="similarity_score_threshold", - top_k=top_k, - score_threshold=score_threshold, - filter={"group_id": [dataset.id]}, - document_ids_filter=document_ids_filter, - ) + documents = [] + if query_type == QueryType.TEXT_QUERY: + documents.extend( + vector.search_by_vector( + query, + search_type="similarity_score_threshold", + top_k=top_k, + score_threshold=score_threshold, + filter={"group_id": [dataset.id]}, + document_ids_filter=document_ids_filter, + ) + ) + if query_type == QueryType.IMAGE_QUERY: + if not dataset.is_multimodal: + return + documents.extend( + vector.search_by_file( + file_id=query, + top_k=top_k, + score_threshold=score_threshold, + filter={"group_id": [dataset.id]}, + document_ids_filter=document_ids_filter, + ) + ) if documents: if ( @@ -250,14 +258,37 @@ class RetrievalService: data_post_processor = DataPostProcessor( str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) - all_documents.extend( - data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents), + if dataset.is_multimodal: + model_manager = ModelManager() + is_support_vision = model_manager.check_model_support_vision( + tenant_id=dataset.tenant_id, + provider=reranking_model.get("reranking_provider_name") or "", + model=reranking_model.get("reranking_model_name") or "", + model_type=ModelType.RERANK, + ) + if is_support_vision: + all_documents.extend( + data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), + query_type=query_type, + ) + ) + else: + # not effective, return original documents + all_documents.extend(documents) + else: + all_documents.extend( + data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), + query_type=query_type, + ) ) - ) else: all_documents.extend(documents) except Exception as e: @@ -339,103 +370,161 @@ class RetrievalService: records = [] include_segment_ids = set() segment_child_map = {} - - # Process documents - for document in documents: - document_id = document.metadata.get("document_id") - if document_id not in dataset_documents: - continue - - dataset_document = dataset_documents[document_id] - if not dataset_document: - continue - - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - # Handle parent-child documents - child_index_node_id = document.metadata.get("doc_id") - child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) - child_chunk = db.session.scalar(child_chunk_stmt) - - if not child_chunk: + segment_file_map = {} + with Session(bind=db.engine, expire_on_commit=False) as session: + # Process documents + for document in documents: + segment_id = None + attachment_info = None + child_chunk = None + document_id = document.metadata.get("document_id") + if document_id not in dataset_documents: continue - segment = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.id == child_chunk.segment_id, - ) - .options( - load_only( - DocumentSegment.id, - DocumentSegment.content, - DocumentSegment.answer, + dataset_document = dataset_documents[document_id] + if not dataset_document: + continue + + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + # Handle parent-child documents + if document.metadata.get("doc_type") == DocType.IMAGE: + attachment_info_dict = cls.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, ) + if attachment_info_dict: + attachment_info = attachment_info_dict["attachment_info"] + segment_id = attachment_info_dict["segment_id"] + else: + child_index_node_id = document.metadata.get("doc_id") + child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) + child_chunk = session.scalar(child_chunk_stmt) + + if not child_chunk: + continue + segment_id = child_chunk.segment_id + + if not segment_id: + continue + + segment = ( + session.query(DocumentSegment) + .where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id == segment_id, + ) + .first() ) - .first() - ) - if not segment: - continue + if not segment: + continue - if segment.id not in include_segment_ids: - include_segment_ids.add(segment.id) - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - map_detail = { - "max_score": document.metadata.get("score", 0.0), - "child_chunks": [child_chunk_detail], - } - segment_child_map[segment.id] = map_detail - record = { - "segment": segment, - } - records.append(record) + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + if child_chunk: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + map_detail = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + segment_child_map[segment.id] = map_detail + record = { + "segment": segment, + } + if attachment_info: + segment_file_map[segment.id] = [attachment_info] + records.append(record) + else: + if child_chunk: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + if segment.id in segment_child_map: + segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) + segment_child_map[segment.id]["max_score"] = max( + segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) + ) + else: + segment_child_map[segment.id] = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + if attachment_info: + if segment.id in segment_file_map: + segment_file_map[segment.id].append(attachment_info) + else: + segment_file_map[segment.id] = [attachment_info] else: - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) - segment_child_map[segment.id]["max_score"] = max( - segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) - ) - else: - # Handle normal documents - index_node_id = document.metadata.get("doc_id") - if not index_node_id: - continue - document_segment_stmt = select(DocumentSegment).where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.index_node_id == index_node_id, - ) - segment = db.session.scalar(document_segment_stmt) + # Handle normal documents + segment = None + if document.metadata.get("doc_type") == DocType.IMAGE: + attachment_info_dict = cls.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, + ) + if attachment_info_dict: + attachment_info = attachment_info_dict["attachment_info"] + segment_id = attachment_info_dict["segment_id"] + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id == segment_id, + ) + segment = session.scalar(document_segment_stmt) + if segment: + segment_file_map[segment.id] = [attachment_info] + else: + index_node_id = document.metadata.get("doc_id") + if not index_node_id: + continue + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + segment = session.scalar(document_segment_stmt) - if not segment: - continue - - include_segment_ids.add(segment.id) - record = { - "segment": segment, - "score": document.metadata.get("score"), # type: ignore - } - records.append(record) + if not segment: + continue + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + record = { + "segment": segment, + "score": document.metadata.get("score"), # type: ignore + } + if attachment_info: + segment_file_map[segment.id] = [attachment_info] + records.append(record) + else: + if attachment_info: + attachment_infos = segment_file_map.get(segment.id, []) + if attachment_info not in attachment_infos: + attachment_infos.append(attachment_info) + segment_file_map[segment.id] = attachment_infos # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore record["score"] = segment_child_map[record["segment"].id]["max_score"] + if record["segment"].id in segment_file_map: + record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment] result = [] for record in records: @@ -447,6 +536,11 @@ class RetrievalService: if not isinstance(child_chunks, list): child_chunks = None + # Extract files, ensuring it's a list or None + files = record.get("files") + if not isinstance(files, list): + files = None + # Extract score, ensuring it's a float or None score_value = record.get("score") score = ( @@ -456,10 +550,149 @@ class RetrievalService: ) # Create RetrievalSegments object - retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score) + retrieval_segment = RetrievalSegments( + segment=segment, child_chunks=child_chunks, score=score, files=files + ) result.append(retrieval_segment) return result except Exception as e: db.session.rollback() raise e + + def _retrieve( + self, + flask_app: Flask, + retrieval_method: RetrievalMethod, + dataset: Dataset, + query: str | None = None, + top_k: int = 4, + score_threshold: float | None = 0.0, + reranking_model: dict | None = None, + reranking_mode: str = "reranking_model", + weights: dict | None = None, + document_ids_filter: list[str] | None = None, + attachment_id: str | None = None, + all_documents: list[Document] = [], + exceptions: list[str] = [], + ): + if not query and not attachment_id: + return + with flask_app.app_context(): + all_documents_item: list[Document] = [] + # Optimize multithreading with thread pools + with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore + futures = [] + if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query: + futures.append( + executor.submit( + self.keyword_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + all_documents=all_documents_item, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + ) + ) + if RetrievalMethod.is_support_semantic_search(retrieval_method): + if query: + futures.append( + executor.submit( + self.embedding_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + query_type=QueryType.TEXT_QUERY, + ) + ) + if attachment_id: + futures.append( + executor.submit( + self.embedding_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=attachment_id, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + query_type=QueryType.IMAGE_QUERY, + ) + ) + if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query: + futures.append( + executor.submit( + self.full_text_index_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + ) + ) + concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED) + + if exceptions: + raise ValueError(";\n".join(exceptions)) + + # Deduplicate documents for hybrid search to avoid duplicate chunks + if retrieval_method == RetrievalMethod.HYBRID_SEARCH: + if attachment_id and reranking_mode == RerankMode.WEIGHTED_SCORE: + all_documents.extend(all_documents_item) + all_documents_item = self._deduplicate_documents(all_documents_item) + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), reranking_mode, reranking_model, weights, False + ) + + query = query or attachment_id + if not query: + return + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY, + ) + + all_documents.extend(all_documents_item) + + @classmethod + def get_segment_attachment_info( + cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session + ) -> dict[str, Any] | None: + upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first() + if upload_file: + attachment_binding = ( + session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.attachment_id == upload_file.id) + .first() + ) + if attachment_binding: + attachment_info = { + "id": upload_file.id, + "name": upload_file.name, + "extension": "." + upload_file.extension, + "mime_type": upload_file.mime_type, + "source_url": sign_upload_file(upload_file.id, upload_file.extension), + "size": upload_file.size, + } + return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id} + return None diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 0beb388693..3a47241293 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,3 +1,4 @@ +import base64 import logging import time from abc import ABC, abstractmethod @@ -12,10 +13,13 @@ from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings +from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from extensions.ext_storage import storage from models.dataset import Dataset, Whitelist +from models.model import UploadFile logger = logging.getLogger(__name__) @@ -203,6 +207,47 @@ class Vector: self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs) logger.info("Embedding %s texts took %s s", len(texts), time.time() - start) + def create_multimodal(self, file_documents: list | None = None, **kwargs): + if file_documents: + start = time.time() + logger.info("start embedding %s files %s", len(file_documents), start) + batch_size = 1000 + total_batches = len(file_documents) + batch_size - 1 + for i in range(0, len(file_documents), batch_size): + batch = file_documents[i : i + batch_size] + batch_start = time.time() + logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch)) + + # Batch query all upload files to avoid N+1 queries + attachment_ids = [doc.metadata["doc_id"] for doc in batch] + stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids)) + upload_files = db.session.scalars(stmt).all() + upload_file_map = {str(f.id): f for f in upload_files} + + file_base64_list = [] + real_batch = [] + for document in batch: + attachment_id = document.metadata["doc_id"] + doc_type = document.metadata["doc_type"] + upload_file = upload_file_map.get(attachment_id) + if upload_file: + blob = storage.load_once(upload_file.key) + file_base64_str = base64.b64encode(blob).decode() + file_base64_list.append( + { + "content": file_base64_str, + "content_type": doc_type, + "file_id": attachment_id, + } + ) + real_batch.append(document) + batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list) + logger.info( + "Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start + ) + self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs) + logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start) + def add_texts(self, documents: list[Document], **kwargs): if kwargs.get("duplicate_check", False): documents = self._filter_duplicate_texts(documents) @@ -223,6 +268,22 @@ class Vector: query_vector = self._embeddings.embed_query(query) return self._vector_processor.search_by_vector(query_vector, **kwargs) + def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]: + upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + + if not upload_file: + return [] + blob = storage.load_once(upload_file.key) + file_base64_str = base64.b64encode(blob).decode() + multimodal_vector = self._embeddings.embed_multimodal_query( + { + "content": file_base64_str, + "content_type": DocType.IMAGE, + "file_id": file_id, + } + ) + return self._vector_processor.search_by_vector(multimodal_vector, **kwargs) + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._vector_processor.search_by_full_text(query, **kwargs) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 2c7bc592c0..84d1e26b34 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -79,6 +79,18 @@ class WeaviateVector(BaseVector): self._client = self._init_client(config) self._attributes = attributes + def __del__(self): + """ + Destructor to properly close the Weaviate client connection. + Prevents connection leaks and resource warnings. + """ + if hasattr(self, "_client") and self._client is not None: + try: + self._client.close() + except Exception as e: + # Ignore errors during cleanup as object is being destroyed + logger.warning("Error closing Weaviate client %s", e, exc_info=True) + def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient: """ Initializes and returns a connected Weaviate client. diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 74a2653e9d..1fe74d3042 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -5,9 +5,9 @@ from sqlalchemy import func, select from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding class DatasetDocumentStore: @@ -120,6 +120,9 @@ class DatasetDocumentStore: db.session.add(segment_document) db.session.flush() + self.add_multimodel_documents_binding( + segment_id=segment_document.id, multimodel_documents=doc.attachments + ) if save_child: if doc.children: for position, child in enumerate(doc.children, start=1): @@ -144,6 +147,9 @@ class DatasetDocumentStore: segment_document.index_node_hash = doc.metadata.get("doc_hash") segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens + self.add_multimodel_documents_binding( + segment_id=segment_document.id, multimodel_documents=doc.attachments + ) if save_child and doc.children: # delete the existing child chunks db.session.query(ChildChunk).where( @@ -233,3 +239,15 @@ class DatasetDocumentStore: document_segment = db.session.scalar(stmt) return document_segment + + def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None): + if multimodel_documents: + for multimodel_document in multimodel_documents: + binding = SegmentAttachmentBinding( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_id, + attachment_id=multimodel_document.metadata["doc_id"], + ) + db.session.add(binding) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 7fb20c1941..3cbc7db75d 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -104,6 +104,88 @@ class CacheEmbedding(Embeddings): return text_embeddings + def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + """Embed file documents.""" + # use doc embedding cache or store if not exists + multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))] + embedding_queue_indices = [] + for i, multimodel_document in enumerate(multimodel_documents): + file_id = multimodel_document["file_id"] + embedding = ( + db.session.query(Embedding) + .filter_by( + model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider + ) + .first() + ) + if embedding: + multimodel_embeddings[i] = embedding.get_embedding() + else: + embedding_queue_indices.append(i) + + # NOTE: avoid closing the shared scoped session here; downstream code may still have pending work + + if embedding_queue_indices: + embedding_queue_multimodel_documents = [multimodel_documents[i] for i in embedding_queue_indices] + embedding_queue_embeddings = [] + try: + model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) + model_schema = model_type_instance.get_model_schema( + self._model_instance.model, self._model_instance.credentials + ) + max_chunks = ( + model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties + else 1 + ) + for i in range(0, len(embedding_queue_multimodel_documents), max_chunks): + batch_multimodel_documents = embedding_queue_multimodel_documents[i : i + max_chunks] + + embedding_result = self._model_instance.invoke_multimodal_embedding( + multimodel_documents=batch_multimodel_documents, + user=self._user, + input_type=EmbeddingInputType.DOCUMENT, + ) + + for vector in embedding_result.embeddings: + try: + # FIXME: type ignore for numpy here + normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore + # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan + if np.isnan(normalized_embedding).any(): + # for issue #11827 float values are not json compliant + logger.warning("Normalized embedding is nan: %s", normalized_embedding) + continue + embedding_queue_embeddings.append(normalized_embedding) + except IntegrityError: + db.session.rollback() + except Exception: + logger.exception("Failed transform embedding") + cache_embeddings = [] + try: + for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings): + multimodel_embeddings[i] = n_embedding + file_id = multimodel_documents[i]["file_id"] + if file_id not in cache_embeddings: + embedding_cache = Embedding( + model_name=self._model_instance.model, + hash=file_id, + provider_name=self._model_instance.provider, + embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL), + ) + embedding_cache.set_embedding(n_embedding) + db.session.add(embedding_cache) + cache_embeddings.append(file_id) + db.session.commit() + except IntegrityError: + db.session.rollback() + except Exception as ex: + db.session.rollback() + logger.exception("Failed to embed documents") + raise ex + + return multimodel_embeddings + def embed_query(self, text: str) -> list[float]: """Embed query text.""" # use doc embedding cache or store if not exists @@ -146,3 +228,46 @@ class CacheEmbedding(Embeddings): raise ex return embedding_results # type: ignore + + def embed_multimodal_query(self, multimodel_document: dict) -> list[float]: + """Embed multimodal documents.""" + # use doc embedding cache or store if not exists + file_id = multimodel_document["file_id"] + embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}" + embedding = redis_client.get(embedding_cache_key) + if embedding: + redis_client.expire(embedding_cache_key, 600) + decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float") + return [float(x) for x in decoded_embedding] + try: + embedding_result = self._model_instance.invoke_multimodal_embedding( + multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY + ) + + embedding_results = embedding_result.embeddings[0] + # FIXME: type ignore for numpy here + embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore + if np.isnan(embedding_results).any(): + raise ValueError("Normalized embedding is nan please try again") + except Exception as ex: + if dify_config.DEBUG: + logger.exception("Failed to embed multimodal document '%s'", multimodel_document["file_id"]) + raise ex + + try: + # encode embedding to base64 + embedding_vector = np.array(embedding_results) + vector_bytes = embedding_vector.tobytes() + # Transform to Base64 + encoded_vector = base64.b64encode(vector_bytes) + # Transform to string + encoded_str = encoded_vector.decode("utf-8") + redis_client.setex(embedding_cache_key, 600, encoded_str) + except Exception as ex: + if dify_config.DEBUG: + logger.exception( + "Failed to add embedding to redis for the multimodal document '%s'", multimodel_document["file_id"] + ) + raise ex + + return embedding_results # type: ignore diff --git a/api/core/rag/embedding/embedding_base.py b/api/core/rag/embedding/embedding_base.py index 9f232ab910..1be55bda80 100644 --- a/api/core/rag/embedding/embedding_base.py +++ b/api/core/rag/embedding/embedding_base.py @@ -9,11 +9,21 @@ class Embeddings(ABC): """Embed search docs.""" raise NotImplementedError + @abstractmethod + def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + """Embed file documents.""" + raise NotImplementedError + @abstractmethod def embed_query(self, text: str) -> list[float]: """Embed query text.""" raise NotImplementedError + @abstractmethod + def embed_multimodal_query(self, multimodel_document: dict) -> list[float]: + """Embed multimodal query.""" + raise NotImplementedError + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Asynchronous Embed search docs.""" raise NotImplementedError diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index 8e92191568..b54a37b49e 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -19,3 +19,4 @@ class RetrievalSegments(BaseModel): segment: DocumentSegment child_chunks: list[RetrievalChildChunk] | None = None score: float | None = None + files: list[dict[str, str | int]] | None = None diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py index aca879df7d..9f66cd9a03 100644 --- a/api/core/rag/entities/citation_metadata.py +++ b/api/core/rag/entities/citation_metadata.py @@ -21,3 +21,4 @@ class RetrievalSourceMetadata(BaseModel): page: int | None = None doc_metadata: dict[str, Any] | None = None title: str | None = None + files: list[dict[str, Any]] | None = None diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index ea9c6bd73a..875bfd1439 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" import os -from typing import cast +from typing import TypedDict import pandas as pd from openpyxl import load_workbook @@ -10,6 +10,12 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document +class Candidate(TypedDict): + idx: int + count: int + map: dict[int, str] + + class ExcelExtractor(BaseExtractor): """Load Excel files. @@ -30,32 +36,38 @@ class ExcelExtractor(BaseExtractor): file_extension = os.path.splitext(self._file_path)[-1].lower() if file_extension == ".xlsx": - wb = load_workbook(self._file_path, data_only=True) - for sheet_name in wb.sheetnames: - sheet = wb[sheet_name] - data = sheet.values - cols = next(data, None) - if cols is None: - continue - df = pd.DataFrame(data, columns=cols) - - df.dropna(how="all", inplace=True) - - for index, row in df.iterrows(): - page_content = [] - for col_index, (k, v) in enumerate(row.items()): - if pd.notna(v): - cell = sheet.cell( - row=cast(int, index) + 2, column=col_index + 1 - ) # +2 to account for header and 1-based index - if cell.hyperlink: - value = f"[{v}]({cell.hyperlink.target})" - page_content.append(f'"{k}":"{value}"') - else: - page_content.append(f'"{k}":"{v}"') - documents.append( - Document(page_content=";".join(page_content), metadata={"source": self._file_path}) - ) + wb = load_workbook(self._file_path, read_only=True, data_only=True) + try: + for sheet_name in wb.sheetnames: + sheet = wb[sheet_name] + header_row_idx, column_map, max_col_idx = self._find_header_and_columns(sheet) + if not column_map: + continue + start_row = header_row_idx + 1 + for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False): + if all(cell.value is None for cell in row): + continue + page_content = [] + for col_idx, cell in enumerate(row): + value = cell.value + if col_idx in column_map: + col_name = column_map[col_idx] + if hasattr(cell, "hyperlink") and cell.hyperlink: + target = getattr(cell.hyperlink, "target", None) + if target: + value = f"[{value}]({target})" + if value is None: + value = "" + elif not isinstance(value, str): + value = str(value) + value = value.strip().replace('"', '\\"') + page_content.append(f'"{col_name}":"{value}"') + if page_content: + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) + finally: + wb.close() elif file_extension == ".xls": excel_file = pd.ExcelFile(self._file_path, engine="xlrd") @@ -63,9 +75,9 @@ class ExcelExtractor(BaseExtractor): df = excel_file.parse(sheet_name=excel_sheet_name) df.dropna(how="all", inplace=True) - for _, row in df.iterrows(): + for _, series_row in df.iterrows(): page_content = [] - for k, v in row.items(): + for k, v in series_row.items(): if pd.notna(v): page_content.append(f'"{k}":"{v}"') documents.append( @@ -75,3 +87,61 @@ class ExcelExtractor(BaseExtractor): raise ValueError(f"Unsupported file extension: {file_extension}") return documents + + def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]: + """ + Scan first N rows to find the most likely header row. + Returns: + header_row_idx: 1-based index of the header row + column_map: Dict mapping 0-based column index to column name + max_col_idx: 1-based index of the last valid column (for iter_rows boundary) + """ + # Store potential candidates: (row_index, non_empty_count, column_map) + candidates: list[Candidate] = [] + + # Limit scan to avoid performance issues on huge files + # We iterate manually to control the read scope + for current_row_idx, row in enumerate(sheet.iter_rows(min_row=1, max_row=scan_rows, values_only=True), start=1): + # Filter out empty cells and build a temp map for this row + # col_idx is 0-based + row_map = {} + for col_idx, cell_value in enumerate(row): + if cell_value is not None and str(cell_value).strip(): + row_map[col_idx] = str(cell_value).strip().replace('"', '\\"') + + if not row_map: + continue + + non_empty_count = len(row_map) + + # Header selection heuristic (implemented): + # - Prefer the first row with at least 2 non-empty columns. + # - Fallback: choose the row with the most non-empty columns + # (tie-breaker: smaller row index). + candidates.append({"idx": current_row_idx, "count": non_empty_count, "map": row_map}) + + if not candidates: + return 0, {}, 0 + + # Choose the best candidate header row. + + best_candidate: Candidate | None = None + + # Strategy: prefer the first row with >= 2 non-empty columns; otherwise fallback. + + for cand in candidates: + if cand["count"] >= 2: + best_candidate = cand + break + + # Fallback: if no row has >= 2 columns, or all have 1, just take the one with max columns + if not best_candidate: + # Sort by count desc, then index asc + candidates.sort(key=lambda x: (-x["count"], x["idx"])) + best_candidate = candidates[0] + + # Determine max_col_idx (1-based for openpyxl) + # It is the index of the last valid column in our map + 1 + max_col_idx = max(best_candidate["map"].keys()) + 1 + + return best_candidate["idx"], best_candidate["map"], max_col_idx diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py index 00004409d6..5166c0c768 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -1,7 +1,9 @@ """Document loader helpers.""" import concurrent.futures -from typing import NamedTuple, cast +from typing import NamedTuple + +import charset_normalizer class FileEncoding(NamedTuple): @@ -27,14 +29,14 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1 sample_size: The number of bytes to read for encoding detection. Default is 1MB. For large files, reading only a sample is sufficient and prevents timeout. """ - import chardet - def read_and_detect(file_path: str): - with open(file_path, "rb") as f: - # Read only a sample of the file for encoding detection - # This prevents timeout on large files while still providing accurate encoding detection - rawdata = f.read(sample_size) - return cast(list[dict], chardet.detect_all(rawdata)) + def read_and_detect(filename: str): + rst = charset_normalizer.from_path(filename) + best = rst.best() + if best is None: + return [] + file_encoding = FileEncoding(encoding=best.encoding, confidence=best.coherence, language=best.language) + return [file_encoding] with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(read_and_detect, file_path) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index c7a5568866..438932cfd6 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -84,22 +84,46 @@ class WordExtractor(BaseExtractor): image_count = 0 image_map = {} - for rel in doc.part.rels.values(): + for rId, rel in doc.part.rels.items(): if "image" in rel.target_ref: image_count += 1 if rel.is_external: url = rel.target_ref - response = ssrf_proxy.get(url) + if not self._is_valid_url(url): + continue + try: + response = ssrf_proxy.get(url) + except Exception as e: + logger.warning("Failed to download image from URL: %s: %s", url, str(e)) + continue if response.status_code == 200: - image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) + image_ext = mimetypes.guess_extension(response.headers.get("Content-Type", "")) if image_ext is None: continue file_uuid = str(uuid.uuid4()) - file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, response.content) - else: - continue + # save file to db + upload_file = UploadFile( + tenant_id=self.tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=file_key, + name=file_key, + size=0, + extension=str(image_ext), + mime_type=mime_type or "", + created_by=self.user_id, + created_by_role=CreatorUserRole.ACCOUNT, + created_at=naive_utc_now(), + used=True, + used_by=self.user_id, + used_at=naive_utc_now(), + ) + db.session.add(upload_file) + db.session.commit() + # Use rId as key for external images since target_part is undefined + image_map[rId] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" else: image_ext = rel.target_ref.split(".")[-1] if image_ext is None: @@ -110,26 +134,28 @@ class WordExtractor(BaseExtractor): mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, rel.target_part.blob) - # save file to db - upload_file = UploadFile( - tenant_id=self.tenant_id, - storage_type=dify_config.STORAGE_TYPE, - key=file_key, - name=file_key, - size=0, - extension=str(image_ext), - mime_type=mime_type or "", - created_by=self.user_id, - created_by_role=CreatorUserRole.ACCOUNT, - created_at=naive_utc_now(), - used=True, - used_by=self.user_id, - used_at=naive_utc_now(), - ) - - db.session.add(upload_file) - db.session.commit() - image_map[rel.target_part] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" + # save file to db + upload_file = UploadFile( + tenant_id=self.tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=file_key, + name=file_key, + size=0, + extension=str(image_ext), + mime_type=mime_type or "", + created_by=self.user_id, + created_by_role=CreatorUserRole.ACCOUNT, + created_at=naive_utc_now(), + used=True, + used_by=self.user_id, + used_at=naive_utc_now(), + ) + db.session.add(upload_file) + db.session.commit() + # Use target_part as key for internal images + image_map[rel.target_part] = ( + f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" + ) return image_map @@ -186,11 +212,17 @@ class WordExtractor(BaseExtractor): image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") if not image_id: continue - image_part = paragraph.part.rels[image_id].target_part - - if image_part in image_map: - image_link = image_map[image_part] - paragraph_content.append(image_link) + rel = paragraph.part.rels.get(image_id) + if rel is None: + continue + # For external images, use image_id as key; for internal, use target_part + if rel.is_external: + if image_id in image_map: + paragraph_content.append(image_map[image_id]) + else: + image_part = rel.target_part + if image_part in image_map: + paragraph_content.append(image_map[image_part]) else: paragraph_content.append(run.text) return "".join(paragraph_content).strip() @@ -227,6 +259,18 @@ class WordExtractor(BaseExtractor): def parse_paragraph(paragraph): paragraph_content = [] + + def append_image_link(image_id, has_drawing): + """Helper to append image link from image_map based on relationship type.""" + rel = doc.part.rels[image_id] + if rel.is_external: + if image_id in image_map and not has_drawing: + paragraph_content.append(image_map[image_id]) + else: + image_part = rel.target_part + if image_part in image_map and not has_drawing: + paragraph_content.append(image_map[image_part]) + for run in paragraph.runs: if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"): # Process drawing type images @@ -243,10 +287,18 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed" ) if embed_id: - image_part = doc.part.related_parts.get(embed_id) - if image_part in image_map: - has_drawing = True - paragraph_content.append(image_map[image_part]) + rel = doc.part.rels.get(embed_id) + if rel is not None and rel.is_external: + # External image: use embed_id as key + if embed_id in image_map: + has_drawing = True + paragraph_content.append(image_map[embed_id]) + else: + # Internal image: use target_part as key + image_part = doc.part.related_parts.get(embed_id) + if image_part in image_map: + has_drawing = True + paragraph_content.append(image_map[image_part]) # Process pict type images shape_elements = run.element.findall( ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict" @@ -261,9 +313,7 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - image_part = doc.part.rels[image_id].target_part - if image_part in image_map and not has_drawing: - paragraph_content.append(image_map[image_part]) + append_image_link(image_id, has_drawing) # Find imagedata element in VML image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata") if image_data is not None: @@ -271,9 +321,7 @@ class WordExtractor(BaseExtractor): "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" ) if image_id and image_id in doc.part.rels: - image_part = doc.part.rels[image_id].target_part - if image_part in image_map and not has_drawing: - paragraph_content.append(image_map[image_part]) + append_image_link(image_id, has_drawing) if run.text.strip(): paragraph_content.append(run.text.strip()) return "".join(paragraph_content) if paragraph_content else "" diff --git a/api/core/rag/index_processor/constant/doc_type.py b/api/core/rag/index_processor/constant/doc_type.py new file mode 100644 index 0000000000..93c8fecb8d --- /dev/null +++ b/api/core/rag/index_processor/constant/doc_type.py @@ -0,0 +1,6 @@ +from enum import StrEnum + + +class DocType(StrEnum): + TEXT = "text" + IMAGE = "image" diff --git a/api/core/rag/index_processor/constant/index_type.py b/api/core/rag/index_processor/constant/index_type.py index 659086e808..09617413f7 100644 --- a/api/core/rag/index_processor/constant/index_type.py +++ b/api/core/rag/index_processor/constant/index_type.py @@ -1,7 +1,12 @@ from enum import StrEnum -class IndexType(StrEnum): +class IndexStructureType(StrEnum): PARAGRAPH_INDEX = "text_model" QA_INDEX = "qa_model" PARENT_CHILD_INDEX = "hierarchical_model" + + +class IndexTechniqueType(StrEnum): + ECONOMY = "economy" + HIGH_QUALITY = "high_quality" diff --git a/api/core/rag/index_processor/constant/query_type.py b/api/core/rag/index_processor/constant/query_type.py new file mode 100644 index 0000000000..342bfef3f7 --- /dev/null +++ b/api/core/rag/index_processor/constant/query_type.py @@ -0,0 +1,6 @@ +from enum import StrEnum + + +class QueryType(StrEnum): + TEXT_QUERY = "text_query" + IMAGE_QUERY = "image_query" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index d4eff53204..8a28eb477a 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,20 +1,34 @@ """Abstract interface for document loader implementations.""" +import cgi +import logging +import mimetypes +import os +import re from abc import ABC, abstractmethod from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Optional +from urllib.parse import unquote, urlparse + +import httpx from configs import dify_config +from core.helper import ssrf_proxy from core.rag.extractor.entity.extract_setting import ExtractSetting -from core.rag.models.document import Document +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.models.document import AttachmentDocument, Document from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.splitter.fixed_text_splitter import ( EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter +from extensions.ext_database import db +from extensions.ext_storage import storage +from models import Account, ToolFile from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument +from models.model import UploadFile if TYPE_CHECKING: from core.model_manager import ModelInstance @@ -28,11 +42,18 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: raise NotImplementedError @abstractmethod - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): raise NotImplementedError @abstractmethod @@ -96,3 +117,178 @@ class BaseIndexProcessor(ABC): ) return character_splitter # type: ignore + + def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]: + """ + Get the content files from the document. + """ + multi_model_documents: list[AttachmentDocument] = [] + text = document.page_content + images = self._extract_markdown_images(text) + if not images: + return multi_model_documents + upload_file_id_list = [] + + for image in images: + # Collect all upload_file_ids including duplicates to preserve occurrence count + + # For data before v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For data after v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For tools directory - direct file formats (e.g., .png, .jpg, etc.) + # Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes) + pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?" + match = re.search(pattern, image) + if match: + if current_user: + tool_file_id = match.group(1) + upload_file_id = self._download_tool_file(tool_file_id, current_user) + if upload_file_id: + upload_file_id_list.append(upload_file_id) + continue + if current_user: + upload_file_id = self._download_image(image.split(" ")[0], current_user) + if upload_file_id: + upload_file_id_list.append(upload_file_id) + + if not upload_file_id_list: + return multi_model_documents + + # Get unique IDs for database query + unique_upload_file_ids = list(set(upload_file_id_list)) + upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all() + + # Create a mapping from ID to UploadFile for quick lookup + upload_file_map = {upload_file.id: upload_file for upload_file in upload_files} + + # Create a Document for each occurrence (including duplicates) + for upload_file_id in upload_file_id_list: + upload_file = upload_file_map.get(upload_file_id) + if upload_file: + multi_model_documents.append( + AttachmentDocument( + page_content=upload_file.name, + metadata={ + "doc_id": upload_file.id, + "doc_hash": "", + "document_id": document.metadata.get("document_id"), + "dataset_id": document.metadata.get("dataset_id"), + "doc_type": DocType.IMAGE, + }, + ) + ) + return multi_model_documents + + def _extract_markdown_images(self, text: str) -> list[str]: + """ + Extract the markdown images from the text. + """ + pattern = r"!\[.*?\]\((.*?)\)" + return re.findall(pattern, text) + + def _download_image(self, image_url: str, current_user: Account) -> str | None: + """ + Download the image from the URL. + Image size must not exceed 2MB. + """ + from services.file_service import FileService + + MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT + + try: + # Download with timeout + response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT) + response.raise_for_status() + + # Check Content-Length header if available + content_length = response.headers.get("Content-Length") + if content_length and int(content_length) > MAX_IMAGE_SIZE: + logging.warning("Image from %s exceeds 2MB limit (size: %s bytes)", image_url, content_length) + return None + + filename = None + + content_disposition = response.headers.get("content-disposition") + if content_disposition: + _, params = cgi.parse_header(content_disposition) + if "filename" in params: + filename = params["filename"] + filename = unquote(filename) + + if not filename: + parsed_url = urlparse(image_url) + # unquote 处理 URL 中的中文 + path = unquote(parsed_url.path) + filename = os.path.basename(path) + + if not filename: + filename = "downloaded_image_file" + + name, current_ext = os.path.splitext(filename) + + content_type = response.headers.get("content-type", "").split(";")[0].strip() + + real_ext = mimetypes.guess_extension(content_type) + + if not current_ext and real_ext or current_ext in [".php", ".jsp", ".asp", ".html"] and real_ext: + filename = f"{name}{real_ext}" + # Download content with size limit + blob = b"" + for chunk in response.iter_bytes(chunk_size=8192): + blob += chunk + if len(blob) > MAX_IMAGE_SIZE: + logging.warning("Image from %s exceeds 2MB limit during download", image_url) + return None + + if not blob: + logging.warning("Image from %s is empty", image_url) + return None + + upload_file = FileService(db.engine).upload_file( + filename=filename, + content=blob, + mimetype=content_type, + user=current_user, + ) + return upload_file.id + except httpx.TimeoutException: + logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT) + return None + except httpx.RequestError as e: + logging.warning("Error downloading image from %s: %s", image_url, str(e)) + return None + except Exception: + logging.exception("Unexpected error downloading image from %s", image_url) + return None + + def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None: + """ + Download the tool file from the ID. + """ + from services.file_service import FileService + + tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first() + if not tool_file: + return None + blob = storage.load_once(tool_file.file_key) + upload_file = FileService(db.engine).upload_file( + filename=tool_file.name, + content=blob, + mimetype=tool_file.mimetype, + user=current_user, + ) + return upload_file.id diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index c987edf342..ea6ab24699 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -1,6 +1,6 @@ """Abstract interface for document loader implementations.""" -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor @@ -19,11 +19,11 @@ class IndexProcessorFactory: if not self._index_type: raise ValueError("Index type must be specified.") - if self._index_type == IndexType.PARAGRAPH_INDEX: + if self._index_type == IndexStructureType.PARAGRAPH_INDEX: return ParagraphIndexProcessor() - elif self._index_type == IndexType.QA_INDEX: + elif self._index_type == IndexStructureType.QA_INDEX: return QAIndexProcessor() - elif self._index_type == IndexType.PARENT_CHILD_INDEX: + elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX: return ParentChildIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 5e5fea7ea9..cf68cff7dc 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -11,14 +11,17 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper +from models.account import Account from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument +from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule @@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: raise ValueError("No process rule found.") @@ -69,6 +72,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if document_node.metadata is not None: document_node.metadata["doc_id"] = doc_id document_node.metadata["doc_hash"] = hash + multimodal_documents = ( + self._get_content_files(document_node, current_user) if document_node.metadata else None + ) + if multimodal_documents: + document_node.attachments = multimodal_documents # delete Splitter character page_content = remove_leading_symbols(document_node.page_content).strip() if len(page_content) > 0: @@ -77,10 +85,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor): all_documents.extend(split_documents) return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) with_keywords = False if with_keywords: keywords_list = kwargs.get("keywords_list") @@ -134,8 +151,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return docs def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): + documents: list[Any] = [] + all_multimodal_documents: list[Any] = [] if isinstance(chunks, list): - documents = [] for content in chunks: metadata = { "dataset_id": dataset.id, @@ -144,26 +162,68 @@ class ParagraphIndexProcessor(BaseIndexProcessor): "doc_hash": helper.generate_text_hash(content), } doc = Document(page_content=content, metadata=metadata) + attachments = self._get_content_files(doc) + if attachments: + doc.attachments = attachments + all_multimodal_documents.extend(attachments) documents.append(doc) - if documents: - # save node to document segment - doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) - # add document segments - doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": - vector = Vector(dataset) - vector.create(documents) - elif dataset.indexing_technique == "economy": - keyword = Keyword(dataset) - keyword.add_texts(documents) else: - raise ValueError("Chunks is not a list") + multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks) + for general_chunk in multimodal_general_structure.general_chunks: + metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(general_chunk.content), + } + doc = Document(page_content=general_chunk.content, metadata=metadata) + if general_chunk.files: + attachments = [] + for file in general_chunk.files: + file_metadata = { + "doc_id": file.id, + "doc_hash": "", + "document_id": document.id, + "dataset_id": dataset.id, + "doc_type": DocType.IMAGE, + } + file_document = AttachmentDocument( + page_content=file.filename or "image_file", metadata=file_metadata + ) + attachments.append(file_document) + all_multimodal_documents.append(file_document) + doc.attachments = attachments + else: + account = AccountService.load_user(document.created_by) + if not account: + raise ValueError("Invalid account") + doc.attachments = self._get_content_files(doc, current_user=account) + if doc.attachments: + all_multimodal_documents.extend(doc.attachments) + documents.append(doc) + if documents: + # save node to document segment + doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) + # add document segments + doc_store.add_documents(docs=documents, save_child=False) + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + vector.create(documents) + if all_multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(all_multimodal_documents) + elif dataset.indexing_technique == "economy": + keyword = Keyword(dataset) + keyword.add_texts(documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: if isinstance(chunks, list): preview = [] for content in chunks: preview.append({"content": content}) - return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)} + return { + "chunk_structure": IndexStructureType.PARAGRAPH_INDEX, + "preview": preview, + "total_segments": len(chunks), + } else: raise ValueError("Chunks is not a list") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 4fa78e2f95..0366f3259f 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -13,14 +13,17 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk +from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from libs import helper +from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument +from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: raise ValueError("No process rule found.") @@ -77,6 +80,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor): page_content = page_content if len(page_content) > 0: document_node.page_content = page_content + multimodel_documents = self._get_content_files(document_node, current_user) + if multimodel_documents: + document_node.attachments = multimodel_documents # parse document to child nodes child_nodes = self._split_child_nodes( document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") @@ -87,6 +93,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor): elif rules.parent_mode == ParentMode.FULL_DOC: page_content = "\n".join([document.page_content for document in documents]) document = Document(page_content=page_content, metadata=documents[0].metadata) + multimodel_documents = self._get_content_files(document) + if multimodel_documents: + document.attachments = multimodel_documents # parse document to child nodes child_nodes = self._split_child_nodes( document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") @@ -104,7 +113,14 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) for document in documents: @@ -114,6 +130,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor): Document.model_validate(child_document.model_dump()) for child_document in child_documents ] vector.create(formatted_child_documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids @@ -244,6 +262,24 @@ class ParentChildIndexProcessor(BaseIndexProcessor): } child_documents.append(ChildDocument(page_content=child, metadata=child_metadata)) doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents) + if parent_child.files and len(parent_child.files) > 0: + attachments = [] + for file in parent_child.files: + file_metadata = { + "doc_id": file.id, + "doc_hash": "", + "document_id": document.id, + "dataset_id": dataset.id, + "doc_type": DocType.IMAGE, + } + file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata) + attachments.append(file_document) + doc.attachments = attachments + else: + account = AccountService.load_user(document.created_by) + if not account: + raise ValueError("Invalid account") + doc.attachments = self._get_content_files(doc, current_user=account) documents.append(doc) if documents: # update document parent mode @@ -267,12 +303,17 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store.add_documents(docs=documents, save_child=True) if dataset.indexing_technique == "high_quality": all_child_documents = [] + all_multimodal_documents = [] for doc in documents: if doc.children: all_child_documents.extend(doc.children) + if doc.attachments: + all_multimodal_documents.extend(doc.attachments) + vector = Vector(dataset) if all_child_documents: - vector = Vector(dataset) vector.create(all_child_documents) + if all_multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(all_multimodal_documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: parent_childs = ParentChildStructureChunk.model_validate(chunks) @@ -280,7 +321,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) return { - "chunk_structure": IndexType.PARENT_CHILD_INDEX, + "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX, "parent_mode": parent_childs.parent_mode, "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 3e3deb0180..1183d5fbd7 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -18,12 +18,13 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document, QAStructureChunk +from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper +from models.account import Account from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule @@ -41,7 +42,7 @@ class QAIndexProcessor(BaseIndexProcessor): ) return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: preview = kwargs.get("preview") process_rule = kwargs.get("process_rule") if not process_rule: @@ -116,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor): try: # Skip the first row - df = pd.read_csv(file) + df = pd.read_csv(file) # type: ignore text_docs = [] for _, row in df.iterrows(): data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]}) @@ -128,10 +129,19 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError(str(e)) return text_docs - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): vector = Vector(dataset) @@ -197,7 +207,7 @@ class QAIndexProcessor(BaseIndexProcessor): for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) return { - "chunk_structure": IndexType.QA_INDEX, + "chunk_structure": IndexStructureType.QA_INDEX, "qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks), } diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 4bd7b1d62e..611fad9a18 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -4,6 +4,8 @@ from typing import Any from pydantic import BaseModel, Field +from core.file import File + class ChildDocument(BaseModel): """Class for storing a piece of text and associated metadata.""" @@ -15,7 +17,19 @@ class ChildDocument(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class AttachmentDocument(BaseModel): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + + provider: str | None = "dify" + + vector: list[float] | None = None + + metadata: dict[str, Any] = Field(default_factory=dict) class Document(BaseModel): @@ -28,12 +42,31 @@ class Document(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) provider: str | None = "dify" children: list[ChildDocument] | None = None + attachments: list[AttachmentDocument] | None = None + + +class GeneralChunk(BaseModel): + """ + General Chunk. + """ + + content: str + files: list[File] | None = None + + +class MultimodalGeneralStructureChunk(BaseModel): + """ + Multimodal General Structure Chunk. + """ + + general_chunks: list[GeneralChunk] + class GeneralStructureChunk(BaseModel): """ @@ -50,6 +83,7 @@ class ParentChildChunk(BaseModel): parent_content: str child_contents: list[str] + files: list[File] | None = None class ParentChildStructureChunk(BaseModel): diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 3561def008..88acb75133 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document @@ -12,6 +13,7 @@ class BaseRerankRunner(ABC): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index e855b0083f..38309d3d77 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,6 +1,15 @@ -from core.model_manager import ModelInstance +import base64 + +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.rerank_entities import RerankResult +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import UploadFile class RerankModelRunner(BaseRerankRunner): @@ -14,6 +23,7 @@ class RerankModelRunner(BaseRerankRunner): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model @@ -24,6 +34,56 @@ class RerankModelRunner(BaseRerankRunner): :param user: unique user id if needed :return: """ + model_manager = ModelManager() + is_support_vision = model_manager.check_model_support_vision( + tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, + provider=self.rerank_model_instance.provider, + model=self.rerank_model_instance.model, + model_type=ModelType.RERANK, + ) + if not is_support_vision: + if query_type == QueryType.TEXT_QUERY: + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + else: + return documents + else: + rerank_result, unique_documents = self.fetch_multimodal_rerank( + query, documents, score_threshold, top_n, user, query_type + ) + + rerank_documents = [] + for result in rerank_result.docs: + if score_threshold is None or result.score >= score_threshold: + # format document + rerank_document = Document( + page_content=result.text, + metadata=unique_documents[result.index].metadata, + provider=unique_documents[result.index].provider, + ) + if rerank_document.metadata is not None: + rerank_document.metadata["score"] = result.score + rerank_documents.append(rerank_document) + + rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True) + return rerank_documents[:top_n] if top_n else rerank_documents + + def fetch_text_rerank( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> tuple[RerankResult, list[Document]]: + """ + Fetch text rerank + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :return: + """ docs = [] doc_ids = set() unique_documents = [] @@ -33,33 +93,99 @@ class RerankModelRunner(BaseRerankRunner): and document.metadata is not None and document.metadata["doc_id"] not in doc_ids ): - doc_ids.add(document.metadata["doc_id"]) - docs.append(document.page_content) - unique_documents.append(document) + if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT: + doc_ids.add(document.metadata["doc_id"]) + docs.append(document.page_content) + unique_documents.append(document) elif document.provider == "external": if document not in unique_documents: docs.append(document.page_content) unique_documents.append(document) - documents = unique_documents - rerank_result = self.rerank_model_instance.invoke_rerank( query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) + return rerank_result, unique_documents - rerank_documents = [] + def fetch_multimodal_rerank( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, + ) -> tuple[RerankResult, list[Document]]: + """ + Fetch multimodal rerank + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :param query_type: query type + :return: rerank result + """ + docs = [] + doc_ids = set() + unique_documents = [] + for document in documents: + if ( + document.provider == "dify" + and document.metadata is not None + and document.metadata["doc_id"] not in doc_ids + ): + if document.metadata.get("doc_type") == DocType.IMAGE: + # Query file info within db.session context to ensure thread-safe access + upload_file = ( + db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first() + ) + if upload_file: + blob = storage.load_once(upload_file.key) + document_file_base64 = base64.b64encode(blob).decode() + document_file_dict = { + "content": document_file_base64, + "content_type": document.metadata["doc_type"], + } + docs.append(document_file_dict) + else: + document_text_dict = { + "content": document.page_content, + "content_type": document.metadata.get("doc_type") or DocType.TEXT, + } + docs.append(document_text_dict) + doc_ids.add(document.metadata["doc_id"]) + unique_documents.append(document) + elif document.provider == "external": + if document not in unique_documents: + docs.append( + { + "content": document.page_content, + "content_type": document.metadata.get("doc_type") or DocType.TEXT, + } + ) + unique_documents.append(document) - for result in rerank_result.docs: - if score_threshold is None or result.score >= score_threshold: - # format document - rerank_document = Document( - page_content=result.text, - metadata=documents[result.index].metadata, - provider=documents[result.index].provider, + documents = unique_documents + if query_type == QueryType.TEXT_QUERY: + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + return rerank_result, unique_documents + elif query_type == QueryType.IMAGE_QUERY: + # Query file info within db.session context to ensure thread-safe access + upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first() + if upload_file: + blob = storage.load_once(upload_file.key) + file_query = base64.b64encode(blob).decode() + file_query_dict = { + "content": file_query, + "content_type": DocType.IMAGE, + } + rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( + query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) - if rerank_document.metadata is not None: - rerank_document.metadata["score"] = result.score - rerank_documents.append(rerank_document) + return rerank_result, unique_documents + else: + raise ValueError(f"Upload file not found for query: {query}") - rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True) - return rerank_documents[:top_n] if top_n else rerank_documents + else: + raise ValueError(f"Query type {query_type} is not supported") diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index c455db6095..18020608cb 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -7,6 +7,8 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.embedding.cached_embedding import CacheEmbedding +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner @@ -24,6 +26,7 @@ class WeightRerankRunner(BaseRerankRunner): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model @@ -43,8 +46,10 @@ class WeightRerankRunner(BaseRerankRunner): and document.metadata is not None and document.metadata["doc_id"] not in doc_ids ): - doc_ids.add(document.metadata["doc_id"]) - unique_documents.append(document) + # weight rerank only support text documents + if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT: + doc_ids.add(document.metadata["doc_id"]) + unique_documents.append(document) else: if document not in unique_documents: unique_documents.append(document) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3db67efb0e..635eab73f0 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -8,6 +8,7 @@ from typing import Any, Union, cast from flask import Flask, current_app from sqlalchemy import and_, or_, select +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus +from core.file import File, FileTransferMethod, FileType from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.entities.metadata_entities import Condition, MetadataCondition -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) +from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from libs.json_in_md_parser import parse_and_check_json_markdown -from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment +from models import UploadFile +from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService @@ -99,7 +105,8 @@ class DatasetRetrieval: message_id: str, memory: TokenBufferMemory | None = None, inputs: Mapping[str, Any] | None = None, - ) -> str | None: + vision_enabled: bool = False, + ) -> tuple[str | None, list[File] | None]: """ Retrieve dataset. :param app_id: app_id @@ -118,7 +125,7 @@ class DatasetRetrieval: """ dataset_ids = config.dataset_ids if len(dataset_ids) == 0: - return None + return None, [] retrieve_config = config.retrieve_config # check model is support tool calling @@ -136,7 +143,7 @@ class DatasetRetrieval: ) if not model_schema: - return None + return None, [] planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features @@ -182,8 +189,8 @@ class DatasetRetrieval: tenant_id, user_id, user_from, - available_datasets, query, + available_datasets, model_instance, model_config, planning_strategy, @@ -213,6 +220,7 @@ class DatasetRetrieval: dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] document_context_list: list[DocumentContext] = [] + context_files: list[File] = [] retrieval_resource_list: list[RetrievalSourceMetadata] = [] # deal with external documents for item in external_documents: @@ -248,6 +256,31 @@ class DatasetRetrieval: score=record.score, ) ) + if vision_enabled: + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == segment.id, + ) + ).all() + if attachments_with_bindings: + for _, upload_file in attachments_with_bindings: + attachment_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=segment.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + url=sign_upload_file(upload_file.id, upload_file.extension), + ) + context_files.append(attachment_info) if show_retrieve_source: for record in records: segment = record.segment @@ -288,8 +321,10 @@ class DatasetRetrieval: hit_callback.return_retriever_resource_info(retrieval_resource_list) if document_context_list: document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) - return str("\n".join([document_context.content for document_context in document_context_list])) - return "" + return str( + "\n".join([document_context.content for document_context in document_context_list]) + ), context_files + return "", context_files def single_retrieve( self, @@ -297,8 +332,8 @@ class DatasetRetrieval: tenant_id: str, user_id: str, user_from: str, - available_datasets: list, query: str, + available_datasets: list, model_instance: ModelInstance, model_config: ModelConfigWithCredentialsEntity, planning_strategy: PlanningStrategy, @@ -336,7 +371,7 @@ class DatasetRetrieval: dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance) self._record_usage(router_usage) - + timer = None if dataset_id: # get retrieval model config dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -406,10 +441,19 @@ class DatasetRetrieval: weights=retrieval_model_config.get("weights", None), document_ids_filter=document_ids_filter, ) - self._on_query(query, [dataset_id], app_id, user_from, user_id) + self._on_query(query, None, [dataset_id], app_id, user_from, user_id) if results: - self._on_retrieval_end(results, message_id, timer) + thread = threading.Thread( + target=self._on_retrieval_end, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "documents": results, + "message_id": message_id, + "timer": timer, + }, + ) + thread.start() return results return [] @@ -421,7 +465,7 @@ class DatasetRetrieval: user_id: str, user_from: str, available_datasets: list, - query: str, + query: str | None, top_k: int, score_threshold: float, reranking_mode: str, @@ -431,10 +475,11 @@ class DatasetRetrieval: message_id: str | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None, metadata_condition: MetadataCondition | None = None, + attachment_ids: list[str] | None = None, ): if not available_datasets: return [] - threads = [] + all_threads = [] all_documents: list[Document] = [] dataset_ids = [dataset.id for dataset in available_datasets] index_type_check = all( @@ -467,102 +512,187 @@ class DatasetRetrieval: 0 ].embedding_model_provider weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model - - for dataset in available_datasets: - index_type = dataset.indexing_technique - document_ids_filter = None - if dataset.provider != "external": - if metadata_condition and not metadata_filter_document_ids: - continue - if metadata_filter_document_ids: - document_ids = metadata_filter_document_ids.get(dataset.id, []) - if document_ids: - document_ids_filter = document_ids - else: - continue - retrieval_thread = threading.Thread( - target=self._retriever, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "dataset_id": dataset.id, - "query": query, - "top_k": top_k, - "all_documents": all_documents, - "document_ids_filter": document_ids_filter, - "metadata_condition": metadata_condition, - }, - ) - threads.append(retrieval_thread) - retrieval_thread.start() - for thread in threads: - thread.join() - with measure_time() as timer: - if reranking_enable: - # do rerank for searched documents - data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) - - all_documents = data_post_processor.invoke( - query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k + if query: + query_thread = threading.Thread( + target=self._multiple_retrieve_thread, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "available_datasets": available_datasets, + "metadata_condition": metadata_condition, + "metadata_filter_document_ids": metadata_filter_document_ids, + "all_documents": all_documents, + "tenant_id": tenant_id, + "reranking_enable": reranking_enable, + "reranking_mode": reranking_mode, + "reranking_model": reranking_model, + "weights": weights, + "top_k": top_k, + "score_threshold": score_threshold, + "query": query, + "attachment_id": None, + }, ) - else: - if index_type == "economy": - all_documents = self.calculate_keyword_score(query, all_documents, top_k) - elif index_type == "high_quality": - all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold) - else: - all_documents = all_documents[:top_k] if top_k else all_documents - - self._on_query(query, dataset_ids, app_id, user_from, user_id) + all_threads.append(query_thread) + query_thread.start() + if attachment_ids: + for attachment_id in attachment_ids: + attachment_thread = threading.Thread( + target=self._multiple_retrieve_thread, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "available_datasets": available_datasets, + "metadata_condition": metadata_condition, + "metadata_filter_document_ids": metadata_filter_document_ids, + "all_documents": all_documents, + "tenant_id": tenant_id, + "reranking_enable": reranking_enable, + "reranking_mode": reranking_mode, + "reranking_model": reranking_model, + "weights": weights, + "top_k": top_k, + "score_threshold": score_threshold, + "query": None, + "attachment_id": attachment_id, + }, + ) + all_threads.append(attachment_thread) + attachment_thread.start() + for thread in all_threads: + thread.join() + self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id) if all_documents: - self._on_retrieval_end(all_documents, message_id, timer) + # add thread to call _on_retrieval_end + retrieval_end_thread = threading.Thread( + target=self._on_retrieval_end, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "documents": all_documents, + "message_id": message_id, + "timer": timer, + }, + ) + retrieval_end_thread.start() + retrieval_resource_list = [] + doc_ids_filter = [] + for document in all_documents: + if document.provider == "dify": + doc_id = document.metadata.get("doc_id") + if doc_id and doc_id not in doc_ids_filter: + doc_ids_filter.append(doc_id) + retrieval_resource_list.append(document) + elif document.provider == "external": + retrieval_resource_list.append(document) + return retrieval_resource_list - return all_documents - - def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None): + def _on_retrieval_end( + self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None + ): """Handle retrieval end.""" - dify_documents = [document for document in documents if document.provider == "dify"] - for document in dify_documents: - if document.metadata is not None: - dataset_document_stmt = select(DatasetDocument).where( - DatasetDocument.id == document.metadata["document_id"] - ) - dataset_document = db.session.scalar(dataset_document_stmt) - if dataset_document: - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk_stmt = select(ChildChunk).where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - child_chunk = db.session.scalar(child_chunk_stmt) - if child_chunk: - _ = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.id == child_chunk.segment_id) - .update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False, - ) - ) + with flask_app.app_context(): + dify_documents = [document for document in documents if document.provider == "dify"] + if not dify_documents: + self._send_trace_task(message_id, documents, timer) + return + + with Session(db.engine) as session: + # Collect all document_ids and batch fetch DatasetDocuments + document_ids = { + doc.metadata["document_id"] + for doc in dify_documents + if doc.metadata and "document_id" in doc.metadata + } + if not document_ids: + self._send_trace_task(message_id, documents, timer) + return + + dataset_docs_stmt = select(DatasetDocument).where(DatasetDocument.id.in_(document_ids)) + dataset_docs = session.scalars(dataset_docs_stmt).all() + dataset_doc_map = {str(doc.id): doc for doc in dataset_docs} + + # Categorize documents by type and collect necessary IDs + parent_child_text_docs: list[tuple[Document, DatasetDocument]] = [] + parent_child_image_docs: list[tuple[Document, DatasetDocument]] = [] + normal_text_docs: list[tuple[Document, DatasetDocument]] = [] + normal_image_docs: list[tuple[Document, DatasetDocument]] = [] + + for doc in dify_documents: + if not doc.metadata or "document_id" not in doc.metadata: + continue + dataset_doc = dataset_doc_map.get(doc.metadata["document_id"]) + if not dataset_doc: + continue + + is_image = doc.metadata.get("doc_type") == DocType.IMAGE + is_parent_child = dataset_doc.doc_form == IndexStructureType.PARENT_CHILD_INDEX + + if is_parent_child: + if is_image: + parent_child_image_docs.append((doc, dataset_doc)) + else: + parent_child_text_docs.append((doc, dataset_doc)) else: - query = db.session.query(DocumentSegment).where( - DocumentSegment.index_node_id == document.metadata["doc_id"] + if is_image: + normal_image_docs.append((doc, dataset_doc)) + else: + normal_text_docs.append((doc, dataset_doc)) + + segment_ids_to_update: set[str] = set() + + # Process PARENT_CHILD_INDEX text documents - batch fetch ChildChunks + if parent_child_text_docs: + index_node_ids = [doc.metadata["doc_id"] for doc, _ in parent_child_text_docs if doc.metadata] + if index_node_ids: + child_chunks_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(index_node_ids)) + child_chunks = session.scalars(child_chunks_stmt).all() + child_chunk_map = {chunk.index_node_id: chunk.segment_id for chunk in child_chunks} + for doc, _ in parent_child_text_docs: + if doc.metadata: + segment_id = child_chunk_map.get(doc.metadata["doc_id"]) + if segment_id: + segment_ids_to_update.add(str(segment_id)) + + # Process non-PARENT_CHILD_INDEX text documents - batch fetch DocumentSegments + if normal_text_docs: + index_node_ids = [doc.metadata["doc_id"] for doc, _ in normal_text_docs if doc.metadata] + if index_node_ids: + segments_stmt = select(DocumentSegment).where(DocumentSegment.index_node_id.in_(index_node_ids)) + segments = session.scalars(segments_stmt).all() + segment_map = {seg.index_node_id: seg.id for seg in segments} + for doc, _ in normal_text_docs: + if doc.metadata: + segment_id = segment_map.get(doc.metadata["doc_id"]) + if segment_id: + segment_ids_to_update.add(str(segment_id)) + + # Process IMAGE documents - batch fetch SegmentAttachmentBindings + all_image_docs = parent_child_image_docs + normal_image_docs + if all_image_docs: + attachment_ids = [ + doc.metadata["doc_id"] + for doc, _ in all_image_docs + if doc.metadata and doc.metadata.get("doc_id") + ] + if attachment_ids: + bindings_stmt = select(SegmentAttachmentBinding).where( + SegmentAttachmentBinding.attachment_id.in_(attachment_ids) ) + bindings = session.scalars(bindings_stmt).all() + segment_ids_to_update.update(str(binding.segment_id) for binding in bindings) - # if 'dataset_id' in document.metadata: - if "dataset_id" in document.metadata: - query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + # Batch update hit_count for all segments + if segment_ids_to_update: + session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False, + ) + session.commit() - # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False - ) + self._send_trace_task(message_id, documents, timer) - db.session.commit() - - # get tracing instance + def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None): + """Send trace task if trace manager is available.""" trace_manager: TraceQueueManager | None = ( self.application_generate_entity.trace_manager if self.application_generate_entity else None ) @@ -573,25 +703,40 @@ class DatasetRetrieval: ) ) - def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str): + def _on_query( + self, + query: str | None, + attachment_ids: list[str] | None, + dataset_ids: list[str], + app_id: str, + user_from: str, + user_id: str, + ): """ Handle query. """ - if not query: + if not query and not attachment_ids: return dataset_queries = [] for dataset_id in dataset_ids: - dataset_query = DatasetQuery( - dataset_id=dataset_id, - content=query, - source="app", - source_app_id=app_id, - created_by_role=user_from, - created_by=user_id, - ) - dataset_queries.append(dataset_query) - if dataset_queries: - db.session.add_all(dataset_queries) + contents = [] + if query: + contents.append({"content_type": QueryType.TEXT_QUERY, "content": query}) + if attachment_ids: + for attachment_id in attachment_ids: + contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}) + if contents: + dataset_query = DatasetQuery( + dataset_id=dataset_id, + content=json.dumps(contents), + source="app", + source_app_id=app_id, + created_by_role=user_from, + created_by=user_id, + ) + dataset_queries.append(dataset_query) + if dataset_queries: + db.session.add_all(dataset_queries) db.session.commit() def _retriever( @@ -603,6 +748,7 @@ class DatasetRetrieval: all_documents: list, document_ids_filter: list[str] | None = None, metadata_condition: MetadataCondition | None = None, + attachment_ids: list[str] | None = None, ): with flask_app.app_context(): dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -611,7 +757,7 @@ class DatasetRetrieval: if not dataset: return [] - if dataset.provider == "external": + if dataset.provider == "external" and query: external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( tenant_id=dataset.tenant_id, dataset_id=dataset_id, @@ -663,6 +809,7 @@ class DatasetRetrieval: reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights", None), document_ids_filter=document_ids_filter, + attachment_ids=attachment_ids, ) all_documents.extend(documents) @@ -1222,3 +1369,86 @@ class DatasetRetrieval: usage = LLMUsage.empty_usage() return full_text, usage + + def _multiple_retrieve_thread( + self, + flask_app: Flask, + available_datasets: list, + metadata_condition: MetadataCondition | None, + metadata_filter_document_ids: dict[str, list[str]] | None, + all_documents: list[Document], + tenant_id: str, + reranking_enable: bool, + reranking_mode: str, + reranking_model: dict | None, + weights: dict[str, Any] | None, + top_k: int, + score_threshold: float, + query: str | None, + attachment_id: str | None, + ): + with flask_app.app_context(): + threads = [] + all_documents_item: list[Document] = [] + index_type = None + for dataset in available_datasets: + index_type = dataset.indexing_technique + document_ids_filter = None + if dataset.provider != "external": + if metadata_condition and not metadata_filter_document_ids: + continue + if metadata_filter_document_ids: + document_ids = metadata_filter_document_ids.get(dataset.id, []) + if document_ids: + document_ids_filter = document_ids + else: + continue + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": flask_app, + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents_item, + "document_ids_filter": document_ids_filter, + "metadata_condition": metadata_condition, + "attachment_ids": [attachment_id] if attachment_id else None, + }, + ) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + + if reranking_enable: + # do rerank for searched documents + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) + if query: + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY, + ) + if attachment_id: + all_documents_item = data_post_processor.invoke( + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.IMAGE_QUERY, + query=attachment_id, + ) + else: + if index_type == IndexTechniqueType.ECONOMY: + if not query: + all_documents_item = [] + else: + all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k) + elif index_type == IndexTechniqueType.HIGH_QUALITY: + all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold) + else: + all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item + if all_documents_item: + all_documents.extend(all_documents_item) diff --git a/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json b/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json new file mode 100644 index 0000000000..1a07869662 --- /dev/null +++ b/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json @@ -0,0 +1,65 @@ +{ + "$id": "https://dify.ai/schemas/v1/multimodal_general_structure.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "version": "1.0.0", + "type": "array", + "title": "Multimodal General Structure", + "description": "Schema for multimodal general structure (v1) - array of objects", + "properties": { + "general_chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The content" + }, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "file name" + }, + "size": { + "type": "number", + "description": "file size" + }, + "extension": { + "type": "string", + "description": "file extension" + }, + "type": { + "type": "string", + "description": "file type" + }, + "mime_type": { + "type": "string", + "description": "file mime type" + }, + "transfer_method": { + "type": "string", + "description": "file transfer method" + }, + "url": { + "type": "string", + "description": "file url" + }, + "related_id": { + "type": "string", + "description": "file related id" + } + }, + "description": "List of files" + } + } + }, + "required": ["content"] + }, + "description": "List of content and files" + } + } +} \ No newline at end of file diff --git a/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json b/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json new file mode 100644 index 0000000000..4ffb590519 --- /dev/null +++ b/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json @@ -0,0 +1,78 @@ +{ + "$id": "https://dify.ai/schemas/v1/multimodal_parent_child_structure.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "version": "1.0.0", + "type": "object", + "title": "Multimodal Parent-Child Structure", + "description": "Schema for multimodal parent-child structure (v1)", + "properties": { + "parent_mode": { + "type": "string", + "description": "The mode of parent-child relationship" + }, + "parent_child_chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "parent_content": { + "type": "string", + "description": "The parent content" + }, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "file name" + }, + "size": { + "type": "number", + "description": "file size" + }, + "extension": { + "type": "string", + "description": "file extension" + }, + "type": { + "type": "string", + "description": "file type" + }, + "mime_type": { + "type": "string", + "description": "file mime type" + }, + "transfer_method": { + "type": "string", + "description": "file transfer method" + }, + "url": { + "type": "string", + "description": "file url" + }, + "related_id": { + "type": "string", + "description": "file related id" + } + }, + "required": ["name", "size", "extension", "type", "mime_type", "transfer_method", "url", "related_id"] + }, + "description": "List of files" + }, + "child_contents": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of child contents" + } + }, + "required": ["parent_content", "child_contents"] + }, + "description": "List of parent-child chunk pairs" + } + }, + "required": ["parent_mode", "parent_child_chunks"] +} \ No newline at end of file diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 5cdf473542..fef3157f27 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -25,6 +25,24 @@ def sign_tool_file(tool_file_id: str, extension: str) -> str: return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" +def sign_upload_file(upload_file_id: str, extension: str) -> str: + """ + sign file to get a temporary url for plugin access + """ + # Use internal URL for plugin/tool file access in Docker environments + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: """ verify signature diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 8f5fa7cab5..f8213d9fd7 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,7 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast import sqlalchemy as sa from sqlalchemy import select @@ -67,6 +67,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class ApiProviderControllerItem(TypedDict): + provider: ApiToolProvider + controller: ApiToolProviderController + + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -655,9 +660,10 @@ class ToolManager: else: filters.append(typ) - with db.session.no_autoflush: + # Use a single session for all database operations to reduce connection overhead + with Session(db.engine) as session: if "builtin" in filters: - builtin_providers = cls.list_builtin_providers(tenant_id) + builtin_providers = list(cls.list_builtin_providers(tenant_id)) # key: provider name, value: provider db_builtin_providers = { @@ -688,57 +694,74 @@ class ToolManager: # get db api providers if "api" in filters: - db_api_providers = db.session.scalars( + db_api_providers = session.scalars( select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id) ).all() - api_provider_controllers: list[dict[str, Any]] = [ - {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} - for provider in db_api_providers - ] + # Batch create controllers + api_provider_controllers: list[ApiProviderControllerItem] = [] + for api_provider in db_api_providers: + try: + controller = ToolTransformService.api_provider_to_controller(api_provider) + api_provider_controllers.append({"provider": api_provider, "controller": controller}) + except Exception: + # Skip invalid providers but continue processing others + logger.warning("Failed to create controller for API provider %s", api_provider.id) - # get labels - labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) - - for api_provider_controller in api_provider_controllers: - user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=api_provider_controller["controller"], - db_provider=api_provider_controller["provider"], - decrypt_credentials=False, - labels=labels.get(api_provider_controller["controller"].provider_id, []), + # Batch get labels for all API providers + if api_provider_controllers: + controllers = cast( + list[ToolProviderController], [item["controller"] for item in api_provider_controllers] ) - result_providers[f"api_provider.{user_provider.name}"] = user_provider + labels = ToolLabelManager.get_tools_labels(controllers) + + for item in api_provider_controllers: + provider_controller = item["controller"] + db_provider = item["provider"] + provider_labels = labels.get(provider_controller.provider_id, []) + user_provider = ToolTransformService.api_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=db_provider, + decrypt_credentials=False, + labels=provider_labels, + ) + result_providers[f"api_provider.{user_provider.name}"] = user_provider if "workflow" in filters: # get workflow providers - workflow_providers = db.session.scalars( + workflow_providers = session.scalars( select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) ).all() workflow_provider_controllers: list[WorkflowToolProviderController] = [] for workflow_provider in workflow_providers: try: - workflow_provider_controllers.append( + workflow_controller: WorkflowToolProviderController = ( ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) ) + workflow_provider_controllers.append(workflow_controller) except Exception: # app has been deleted - pass + logger.exception("Failed to transform workflow provider %s to controller", workflow_provider.id) + continue + # Batch get labels for workflow providers + if workflow_provider_controllers: + workflow_controllers: list[ToolProviderController] = [ + cast(ToolProviderController, controller) for controller in workflow_provider_controllers + ] + labels = ToolLabelManager.get_tools_labels(workflow_controllers) - labels = ToolLabelManager.get_tools_labels( - [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] - ) + for workflow_provider_controller in workflow_provider_controllers: + provider_labels = labels.get(workflow_provider_controller.provider_id, []) + user_provider = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=workflow_provider_controller, + labels=provider_labels, + ) + result_providers[f"workflow_provider.{user_provider.name}"] = user_provider - for provider_controller in workflow_provider_controllers: - user_provider = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=provider_controller, - labels=labels.get(provider_controller.provider_id, []), - ) - result_providers[f"workflow_provider.{user_provider.name}"] = user_provider if "mcp" in filters: - with Session(db.engine) as session: - mcp_service = MCPToolManageService(session=session) - mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True) + mcp_service = MCPToolManageService(session=session) + mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True) for mcp_provider in mcp_providers: result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 105823f896..0f9a91a111 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str: """ # Match Unicode ranges for punctuation and symbols # FIXME this pattern is confused quick fix for #11868 maybe refactor it later - pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" + pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+' return re.sub(pattern, "", text) diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index ef6913d0bd..ed3ed3e0de 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Any, cast from urllib.parse import unquote -import chardet +import charset_normalizer import cloudscraper from readabilipy import simple_json_from_html_string @@ -69,9 +69,12 @@ def get_url(url: str, user_agent: str | None = None) -> str: if response.status_code != 200: return f"URL returned status code {response.status_code}." - # Detect encoding using chardet - detected_encoding = chardet.detect(response.content) - encoding = detected_encoding["encoding"] + # Detect encoding using charset_normalizer + detected_encoding = charset_normalizer.from_bytes(response.content).best() + if detected_encoding: + encoding = detected_encoding.encoding + else: + encoding = "utf-8" if encoding: try: content = response.content.decode(encoding) diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 4852e9d2d8..0439fb1d60 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -221,7 +221,7 @@ class WorkflowToolProviderController(ToolProviderController): session.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == self.provider_id, + WorkflowToolProvider.id == self.provider_id, ) .first() ) diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 1751b45d9b..30334f5da8 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -203,7 +203,7 @@ class WorkflowTool(Tool): Resolve user object in both HTTP and worker contexts. In HTTP context: dereference the current_user LocalProxy (can return Account or EndUser). - In worker context: load Account from database by user_id (only returns Account, never EndUser). + In worker context: load Account(knowledge pipeline) or EndUser(trigger) from database by user_id. Returns: Account | EndUser | None: The resolved user object, or None if resolution fails. @@ -224,24 +224,28 @@ class WorkflowTool(Tool): logger.warning("Failed to resolve user from request context: %s", e) return None - def _resolve_user_from_database(self, user_id: str) -> Account | None: + def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None: """ Resolve user from database (worker/Celery context). """ - user_stmt = select(Account).where(Account.id == user_id) - user = db.session.scalar(user_stmt) - if not user: - return None - tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) tenant = db.session.scalar(tenant_stmt) if not tenant: return None - user.current_tenant = tenant + user_stmt = select(Account).where(Account.id == user_id) + user = db.session.scalar(user_stmt) + if user: + user.current_tenant = tenant + return user - return user + end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id) + end_user = db.session.scalar(end_user_stmt) + if end_user: + return end_user + + return None def _get_workflow(self, app_id: str, version: str) -> Workflow: """ diff --git a/api/core/workflow/graph_engine/event_management/event_manager.py b/api/core/workflow/graph_engine/event_management/event_manager.py index 71043b9a43..ae2e659543 100644 --- a/api/core/workflow/graph_engine/event_management/event_manager.py +++ b/api/core/workflow/graph_engine/event_management/event_manager.py @@ -2,6 +2,7 @@ Unified event manager for collecting and emitting events. """ +import logging import threading import time from collections.abc import Generator @@ -12,6 +13,8 @@ from core.workflow.graph_events import GraphEngineEvent from ..layers.base import GraphEngineLayer +_logger = logging.getLogger(__name__) + @final class ReadWriteLock: @@ -180,5 +183,4 @@ class EventManager: try: layer.on_event(event) except Exception: - # Silently ignore layer errors during collection - pass + _logger.exception("Error in layer on_event, layer_type=%s", type(layer)) diff --git a/api/core/workflow/graph_engine/manager.py b/api/core/workflow/graph_engine/manager.py index f05d43d8ad..0577ba8f02 100644 --- a/api/core/workflow/graph_engine/manager.py +++ b/api/core/workflow/graph_engine/manager.py @@ -6,12 +6,15 @@ using the new Redis command channel, without requiring user permission checks. Supports stop, pause, and resume operations. """ +import logging from typing import final from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand from extensions.ext_redis import redis_client +logger = logging.getLogger(__name__) + @final class GraphEngineManager: @@ -57,4 +60,4 @@ class GraphEngineManager: except Exception: # Silently fail if Redis is unavailable # The legacy control mechanisms will still work - pass + logger.exception("Failed to send graph engine command %s for task %s", command.__class__.__name__, task_id) diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index ebf93f2fc2..e4fa52f444 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -3,6 +3,7 @@ from datetime import datetime from pydantic import Field +from core.file import File from core.model_runtime.entities.llm_entities import LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.pause_reason import PauseReason @@ -14,6 +15,7 @@ from .base import NodeEventBase class RunRetrieverResourceEvent(NodeEventBase): retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") + context_files: list[File] | None = Field(default=None, description="context files") class ModelInvokeCompletedEvent(NodeEventBase): diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index e816e16d74..5aab6bbde4 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -59,7 +59,7 @@ class OutputVariableEntity(BaseModel): """ variable: str - value_type: OutputVariableType + value_type: OutputVariableType = OutputVariableType.ANY value_selector: Sequence[str] @field_validator("value_type", mode="before") diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index f05c5f9873..14ebd1f9ae 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -7,7 +7,7 @@ import tempfile from collections.abc import Mapping, Sequence from typing import Any -import chardet +import charset_normalizer import docx import pandas as pd import pypandoc @@ -228,9 +228,12 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) def _extract_text_from_plain_text(file_content: bytes) -> str: try: - # Detect encoding using chardet - result = chardet.detect(file_content) - encoding = result["encoding"] + # Detect encoding using charset_normalizer + result = charset_normalizer.from_bytes(file_content, cp_isolation=["utf_8", "latin_1", "cp1252"]).best() + if result: + encoding = result.encoding + else: + encoding = "utf-8" # Fallback to utf-8 if detection fails if not encoding: @@ -247,9 +250,12 @@ def _extract_text_from_plain_text(file_content: bytes) -> str: def _extract_text_from_json(file_content: bytes) -> str: try: - # Detect encoding using chardet - result = chardet.detect(file_content) - encoding = result["encoding"] + # Detect encoding using charset_normalizer + result = charset_normalizer.from_bytes(file_content).best() + if result: + encoding = result.encoding + else: + encoding = "utf-8" # Fallback to utf-8 if detection fails if not encoding: @@ -269,9 +275,12 @@ def _extract_text_from_json(file_content: bytes) -> str: def _extract_text_from_yaml(file_content: bytes) -> str: """Extract the content from yaml file""" try: - # Detect encoding using chardet - result = chardet.detect(file_content) - encoding = result["encoding"] + # Detect encoding using charset_normalizer + result = charset_normalizer.from_bytes(file_content).best() + if result: + encoding = result.encoding + else: + encoding = "utf-8" # Fallback to utf-8 if detection fails if not encoding: @@ -424,9 +433,12 @@ def _extract_text_from_file(file: File): def _extract_text_from_csv(file_content: bytes) -> str: try: - # Detect encoding using chardet - result = chardet.detect(file_content) - encoding = result["encoding"] + # Detect encoding using charset_normalizer + result = charset_normalizer.from_bytes(file_content).best() + if result: + encoding = result.encoding + else: + encoding = "utf-8" # Fallback to utf-8 if detection fails if not encoding: diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 5a7db6e0e6..e323533835 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from email.message import Message from typing import Any, Literal +import charset_normalizer import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator @@ -96,10 +97,12 @@ class HttpRequestNodeData(BaseNodeData): class Response: headers: dict[str, str] response: httpx.Response + _cached_text: str | None def __init__(self, response: httpx.Response): self.response = response self.headers = dict(response.headers) + self._cached_text = None @property def is_file(self): @@ -159,7 +162,31 @@ class Response: @property def text(self) -> str: - return self.response.text + """ + Get response text with robust encoding detection. + + Uses charset_normalizer for better encoding detection than httpx's default, + which helps handle Chinese and other non-ASCII characters properly. + """ + # Check cache first + if hasattr(self, "_cached_text") and self._cached_text is not None: + return self._cached_text + + # Try charset_normalizer for robust encoding detection first + detected_encoding = charset_normalizer.from_bytes(self.response.content).best() + if detected_encoding and detected_encoding.encoding: + try: + text = self.response.content.decode(detected_encoding.encoding) + self._cached_text = text + return text + except (UnicodeDecodeError, TypeError, LookupError): + # Fallback to httpx's encoding detection if charset_normalizer fails + pass + + # Fallback to httpx's built-in encoding detection + text = self.response.text + self._cached_text = text + return text @property def content(self) -> bytes: diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 7b5b9c9e86..f0c84872fb 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -412,16 +412,20 @@ class Executor: body_string += f"--{boundary}\r\n" body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' # decode content safely - try: - body_string += content.decode("utf-8") - except UnicodeDecodeError: - body_string += content.decode("utf-8", errors="replace") - body_string += "\r\n" + # Do not decode binary content; use a placeholder with file metadata instead. + # Includes filename, size, and MIME type for better logging context. + body_string += ( + f"\r\n" + ) body_string += f"--{boundary}--\r\n" elif self.node_data.body: if self.content: + # If content is bytes, do not decode it; show a placeholder with size. + # Provides content size information for binary data without exposing the raw bytes. if isinstance(self.content, bytes): - body_string = self.content.decode("utf-8", errors="replace") + body_string = f"" else: body_string = self.content elif self.data and self.node_data.body.type == "x-www-form-urlencoded": diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 8aa6a5016f..86bb2495e7 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -114,7 +114,8 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ type: str = "knowledge-retrieval" - query_variable_selector: list[str] + query_variable_selector: list[str] | None | str = None + query_attachment_selector: list[str] | None | str = None dataset_ids: list[str] retrieval_mode: Literal["single", "multiple"] multiple_retrieval_config: MultipleRetrievalConfig | None = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 1b57d23e24..adc474bd60 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -25,6 +25,8 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import ( + ArrayFileSegment, + FileSegment, StringSegment, ) from core.variables.segments import ArrayObjectSegment @@ -119,20 +121,41 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD return "1" def _run(self) -> NodeRunResult: - # extract variables - variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector) - if not isinstance(variable, StringSegment): + if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, - error="Query variable is not string type.", - ) - query = variable.value - variables = {"query": query} - if not query: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." + process_data={}, + outputs={}, + metadata={}, + llm_usage=LLMUsage.empty_usage(), ) + variables: dict[str, Any] = {} + # extract variables + if self._node_data.query_variable_selector: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) + if not isinstance(variable, StringSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Query variable is not string type.", + ) + query = variable.value + variables["query"] = query + + if self._node_data.query_attachment_selector: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector) + if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Attachments variable is not array file or file type.", + ) + if isinstance(variable, ArrayFileSegment): + variables["attachments"] = variable.value + else: + variables["attachments"] = [variable.value] + # TODO(-LAN-): Move this check outside. # check rate limit knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) @@ -161,7 +184,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD # retrieve knowledge usage = LLMUsage.empty_usage() try: - results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query) + results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables) outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -198,12 +221,16 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD db.session.close() def _fetch_dataset_retriever( - self, node_data: KnowledgeRetrievalNodeData, query: str + self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[dict[str, Any]], LLMUsage]: usage = LLMUsage.empty_usage() available_datasets = [] dataset_ids = node_data.dataset_ids - + query = variables.get("query") + attachments = variables.get("attachments") + metadata_filter_document_ids = None + metadata_condition = None + metadata_usage = LLMUsage.empty_usage() # Subquery: Count the number of available documents for each dataset subquery = ( db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) @@ -234,13 +261,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if not dataset: continue available_datasets.append(dataset) - metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( - [dataset.id for dataset in available_datasets], query, node_data - ) - usage = self._merge_usage(usage, metadata_usage) + if query: + metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( + [dataset.id for dataset in available_datasets], query, node_data + ) + usage = self._merge_usage(usage, metadata_usage) all_documents = [] dataset_retrieval = DatasetRetrieval() - if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: + if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query: # fetch model config if node_data.single_retrieval_config is None: raise ValueError("single_retrieval_config is required") @@ -272,7 +300,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) - elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: + elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": @@ -319,6 +347,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD reranking_enable=node_data.multiple_retrieval_config.reranking_enable, metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, + attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, ) usage = self._merge_usage(usage, dataset_retrieval.llm_usage) @@ -327,7 +356,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD retrieval_resource_list = [] # deal with external documents for item in external_documents: - source = { + source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = { "metadata": { "_source": "knowledge", "dataset_id": item.metadata.get("dataset_id"), @@ -384,6 +413,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD "doc_metadata": document.doc_metadata, }, "title": document.name, + "files": list(record.files) if record.files else None, } if segment.answer: source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" @@ -393,13 +423,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if retrieval_resource_list: retrieval_resource_list = sorted( retrieval_resource_list, - key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, + key=self._score, # type: ignore[arg-type, return-value] reverse=True, ) for position, item in enumerate(retrieval_resource_list, start=1): - item["metadata"]["position"] = position + item["metadata"]["position"] = position # type: ignore[index] return retrieval_resource_list, usage + def _score(self, item: dict[str, Any]) -> float: + meta = item.get("metadata") + if isinstance(meta, dict): + s = meta.get("score") + if isinstance(s, (int, float)): + return float(s) + return 0.0 + def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]: @@ -659,7 +697,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data) variable_mapping = {} - variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector + if typed_node_data.query_variable_selector: + variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector + if typed_node_data.query_attachment_selector: + variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector return variable_mapping def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 1a2473e0bb..04e2802191 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -7,8 +7,10 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal +from sqlalchemy import select + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file import FileType, file_manager +from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output @@ -44,6 +46,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.tools.signature import sign_upload_file from core.variables import ( ArrayFileSegment, ArraySegment, @@ -72,6 +75,9 @@ from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.runtime import VariablePool +from extensions.ext_database import db +from models.dataset import SegmentAttachmentBinding +from models.model import UploadFile from . import llm_utils from .entities import ( @@ -179,12 +185,17 @@ class LLMNode(Node[LLMNodeData]): # fetch context value generator = self._fetch_context(node_data=self.node_data) context = None + context_files: list[File] = [] for event in generator: context = event.context + context_files = event.context_files or [] yield event if context: node_inputs["#context#"] = context + if context_files: + node_inputs["#context_files#"] = [file.model_dump() for file in context_files] + # fetch model config model_instance, model_config = LLMNode._fetch_model_config( node_data_model=self.node_data.model, @@ -220,6 +231,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, tenant_id=self.tenant_id, + context_files=context_files, ) # handle invoke result @@ -322,6 +334,7 @@ class LLMNode(Node[LLMNodeData]): inputs=node_inputs, process_data=process_data, error_type=type(e).__name__, + llm_usage=usage, ) ) except Exception as e: @@ -332,6 +345,8 @@ class LLMNode(Node[LLMNodeData]): error=str(e), inputs=node_inputs, process_data=process_data, + error_type=type(e).__name__, + llm_usage=usage, ) ) @@ -654,10 +669,13 @@ class LLMNode(Node[LLMNodeData]): context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) if context_value_variable: if isinstance(context_value_variable, StringSegment): - yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) + yield RunRetrieverResourceEvent( + retriever_resources=[], context=context_value_variable.value, context_files=[] + ) elif isinstance(context_value_variable, ArraySegment): context_str = "" original_retriever_resource: list[RetrievalSourceMetadata] = [] + context_files: list[File] = [] for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" @@ -670,9 +688,34 @@ class LLMNode(Node[LLMNodeData]): retriever_resource = self._convert_to_original_retriever_resource(item) if retriever_resource: original_retriever_resource.append(retriever_resource) - + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == retriever_resource.segment_id, + ) + ).all() + if attachments_with_bindings: + for _, upload_file in attachments_with_bindings: + attachment_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + url=sign_upload_file(upload_file.id, upload_file.extension), + ) + context_files.append(attachment_info) yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, context=context_str.strip() + retriever_resources=original_retriever_resource, + context=context_str.strip(), + context_files=context_files, ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None: @@ -700,6 +743,7 @@ class LLMNode(Node[LLMNodeData]): content=context_dict.get("content"), page=metadata.get("page"), doc_metadata=metadata.get("doc_metadata"), + files=context_dict.get("files"), ) return source @@ -741,6 +785,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, + context_files: list["File"] | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -853,6 +898,23 @@ class LLMNode(Node[LLMNodeData]): else: prompt_messages.append(UserPromptMessage(content=file_prompts)) + # The context_files + if vision_enabled and context_files: + file_prompts = [] + for file in context_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + # Remove empty messages and filter unsupported content filtered_prompt_messages = [] for prompt_message in prompt_messages: diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py index 5fc363257b..c55ad346bf 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/workflow/nodes/node_factory.py @@ -64,7 +64,10 @@ class DifyNodeFactory(NodeFactory): if not node_mapping: raise ValueError(f"No class mapping found for node type: {node_type}") - node_class = node_mapping.get(LATEST_VERSION) + latest_node_class = node_mapping.get(LATEST_VERSION) + node_version = str(node_data.get("version", "1")) + matched_node_class = node_mapping.get(node_version) + node_class = matched_node_class or latest_node_class if not node_class: raise ValueError(f"No latest version class found for node type: {node_type}") diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index db3d4d4aac..4a3e8e56f8 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -221,6 +221,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), + error_type=type(e).__name__, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 6d2938771f..38effa79f7 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,3 +1,8 @@ +from typing import Any + +from jsonschema import Draft7Validator, ValidationError + +from core.app.app_config.entities import VariableEntityType from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -15,6 +20,7 @@ class StartNode(Node[StartNodeData]): def _run(self) -> NodeRunResult: node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + self._validate_and_normalize_json_object_inputs(node_inputs) system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() # TODO: System variables should be directly accessible, no need for special handling @@ -24,3 +30,27 @@ class StartNode(Node[StartNodeData]): outputs = dict(node_inputs) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) + + def _validate_and_normalize_json_object_inputs(self, node_inputs: dict[str, Any]) -> None: + for variable in self.node_data.variables: + if variable.type != VariableEntityType.JSON_OBJECT: + continue + + key = variable.variable + value = node_inputs.get(key) + + if value is None and variable.required: + raise ValueError(f"{key} is required in input form") + + if not isinstance(value, dict): + raise ValueError(f"{key} must be a JSON object") + + schema = variable.json_schema + if not schema: + continue + + try: + Draft7Validator(schema).validate(value) + except ValidationError as e: + raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}") + node_inputs[key] = value diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 1b44d8a1e2..bac2fbef47 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,9 +1,13 @@ +import logging + from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolEntity from events.app_event import app_draft_workflow_was_synced +logger = logging.getLogger(__name__) + @app_draft_workflow_was_synced.connect def handle(sender, **kwargs): @@ -30,6 +34,10 @@ def handle(sender, **kwargs): identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}", ) manager.delete_tool_parameters_cache() - except: + except Exception: # tool dose not exist - pass + logger.exception( + "Failed to delete tool parameters cache for workflow %s node %s", + app.id, + node_data.get("id"), + ) diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index e1c96fb050..84266ab0fa 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -256,7 +256,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation] now = datetime_utils.naive_utc_now() last_update = _get_last_update_timestamp(cache_key) - if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS: + if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS: # type: ignore update_values["last_used"] = values.last_used _set_last_update_timestamp(cache_key, now) diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 44b50e42ee..725e5351e6 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -6,6 +6,7 @@ BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEAD SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization") AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN) FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN) +EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id") def init_app(app: DifyApp): @@ -25,6 +26,7 @@ def init_app(app: DifyApp): service_api_bp, allow_headers=list(SERVICE_API_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=list(EXPOSED_HEADERS), ) app.register_blueprint(service_api_bp) @@ -34,7 +36,7 @@ def init_app(app: DifyApp): supports_credentials=True, allow_headers=list(AUTHENTICATED_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], + expose_headers=list(EXPOSED_HEADERS), ) app.register_blueprint(web_bp) @@ -44,7 +46,7 @@ def init_app(app: DifyApp): supports_credentials=True, allow_headers=list(AUTHENTICATED_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], + expose_headers=list(EXPOSED_HEADERS), ) app.register_blueprint(console_app_bp) @@ -52,6 +54,7 @@ def init_app(app: DifyApp): files_bp, allow_headers=list(FILES_HEADERS), methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=list(EXPOSED_HEADERS), ) app.register_blueprint(files_bp) @@ -63,5 +66,6 @@ def init_app(app: DifyApp): trigger_bp, allow_headers=["Content-Type", "Authorization", "X-App-Code"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"], + expose_headers=list(EXPOSED_HEADERS), ) app.register_blueprint(trigger_bp) diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 79d49aba5e..000d03ac41 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -7,6 +7,7 @@ from logging.handlers import RotatingFileHandler import flask from configs import dify_config +from core.helper.trace_id_helper import get_trace_id_from_otel_context from dify_app import DifyApp @@ -76,7 +77,9 @@ class RequestIdFilter(logging.Filter): # the logging format. Note that we're checking if we're in a request # context, as we may want to log things before Flask is fully loaded. def filter(self, record): + trace_id = get_trace_id_from_otel_context() or "" record.req_id = get_request_id() if flask.has_request_context() else "" + record.trace_id = trace_id return True @@ -84,6 +87,8 @@ class RequestIdFormatter(logging.Formatter): def format(self, record): if not hasattr(record, "req_id"): record.req_id = "" + if not hasattr(record, "trace_id"): + record.trace_id = "" return super().format(record) diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index 20ac2503a2..40a915e68c 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -1,148 +1,22 @@ import atexit -import contextlib import logging import os import platform import socket -import sys from typing import Union -import flask -from celery.signals import worker_init -from flask_login import user_loaded_from_request, user_logged_in - from configs import dify_config from dify_app import DifyApp -from libs.helper import extract_tenant_id -from models import Account, EndUser logger = logging.getLogger(__name__) -@user_logged_in.connect -@user_loaded_from_request.connect -def on_user_loaded(_sender, user: Union["Account", "EndUser"]): - if dify_config.ENABLE_OTEL: - from opentelemetry.trace import get_current_span - - if user: - try: - current_span = get_current_span() - tenant_id = extract_tenant_id(user) - if not tenant_id: - return - if current_span: - current_span.set_attribute("service.tenant.id", tenant_id) - current_span.set_attribute("service.user.id", user.id) - except Exception: - logger.exception("Error setting tenant and user attributes") - pass - - def init_app(app: DifyApp): - from opentelemetry.semconv.trace import SpanAttributes - - def is_celery_worker(): - return "celery" in sys.argv[0].lower() - - def instrument_exception_logging(): - exception_handler = ExceptionLoggingHandler() - logging.getLogger().addHandler(exception_handler) - - def init_flask_instrumentor(app: DifyApp): - meter = get_meter("http_metrics", version=dify_config.project.version) - _http_response_counter = meter.create_counter( - "http.server.response.count", - description="Total number of HTTP responses by status code, method and target", - unit="{response}", - ) - - def response_hook(span: Span, status: str, response_headers: list): - if span and span.is_recording(): - try: - if status.startswith("2"): - span.set_status(StatusCode.OK) - else: - span.set_status(StatusCode.ERROR, status) - - status = status.split(" ")[0] - status_code = int(status) - status_class = f"{status_code // 100}xx" - attributes: dict[str, str | int] = {"status_code": status_code, "status_class": status_class} - request = flask.request - if request and request.url_rule: - attributes[SpanAttributes.HTTP_TARGET] = str(request.url_rule.rule) - if request and request.method: - attributes[SpanAttributes.HTTP_METHOD] = str(request.method) - _http_response_counter.add(1, attributes) - except Exception: - logger.exception("Error setting status and attributes") - pass - - instrumentor = FlaskInstrumentor() - if dify_config.DEBUG: - logger.info("Initializing Flask instrumentor") - instrumentor.instrument_app(app, response_hook=response_hook) - - def init_sqlalchemy_instrumentor(app: DifyApp): - with app.app_context(): - engines = list(app.extensions["sqlalchemy"].engines.values()) - SQLAlchemyInstrumentor().instrument(enable_commenter=True, engines=engines) - - def setup_context_propagation(): - # Configure propagators - set_global_textmap( - CompositePropagator( - [ - TraceContextTextMapPropagator(), # W3C trace context - B3Format(), # B3 propagation (used by many systems) - ] - ) - ) - - def shutdown_tracer(): - provider = trace.get_tracer_provider() - if hasattr(provider, "force_flush"): - provider.force_flush() - - class ExceptionLoggingHandler(logging.Handler): - """Custom logging handler that creates spans for logging.exception() calls""" - - def emit(self, record: logging.LogRecord): - with contextlib.suppress(Exception): - if record.exc_info: - tracer = get_tracer_provider().get_tracer("dify.exception.logging") - with tracer.start_as_current_span( - "log.exception", - attributes={ - "log.level": record.levelname, - "log.message": record.getMessage(), - "log.logger": record.name, - "log.file.path": record.pathname, - "log.file.line": record.lineno, - }, - ) as span: - span.set_status(StatusCode.ERROR) - if record.exc_info[1]: - span.record_exception(record.exc_info[1]) - span.set_attribute("exception.message", str(record.exc_info[1])) - if record.exc_info[0]: - span.set_attribute("exception.type", record.exc_info[0].__name__) - - from opentelemetry import trace from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter - from opentelemetry.instrumentation.celery import CeleryInstrumentor - from opentelemetry.instrumentation.flask import FlaskInstrumentor - from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor - from opentelemetry.instrumentation.redis import RedisInstrumentor - from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor - from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider - from opentelemetry.propagate import set_global_textmap - from opentelemetry.propagators.b3 import B3Format - from opentelemetry.propagators.composite import CompositePropagator + from opentelemetry.metrics import set_meter_provider from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader from opentelemetry.sdk.resources import Resource @@ -153,9 +27,10 @@ def init_app(app: DifyApp): ) from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio from opentelemetry.semconv.resource import ResourceAttributes - from opentelemetry.trace import Span, get_tracer_provider, set_tracer_provider - from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator - from opentelemetry.trace.status import StatusCode + from opentelemetry.trace import set_tracer_provider + + from extensions.otel.instrumentation import init_instruments + from extensions.otel.runtime import setup_context_propagation, shutdown_tracer setup_context_propagation() # Initialize OpenTelemetry @@ -177,6 +52,7 @@ def init_app(app: DifyApp): ) sampler = ParentBasedTraceIdRatio(dify_config.OTEL_SAMPLING_RATE) provider = TracerProvider(resource=resource, sampler=sampler) + set_tracer_provider(provider) exporter: Union[GRPCSpanExporter, HTTPSpanExporter, ConsoleSpanExporter] metric_exporter: Union[GRPCMetricExporter, HTTPMetricExporter, ConsoleMetricExporter] @@ -231,29 +107,11 @@ def init_app(app: DifyApp): export_timeout_millis=dify_config.OTEL_METRIC_EXPORT_TIMEOUT, ) set_meter_provider(MeterProvider(resource=resource, metric_readers=[reader])) - if not is_celery_worker(): - init_flask_instrumentor(app) - CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument() - instrument_exception_logging() - init_sqlalchemy_instrumentor(app) - RedisInstrumentor().instrument() - HTTPXClientInstrumentor().instrument() + + init_instruments(app) + atexit.register(shutdown_tracer) def is_enabled(): return dify_config.ENABLE_OTEL - - -@worker_init.connect(weak=False) -def init_celery_worker(*args, **kwargs): - if dify_config.ENABLE_OTEL: - from opentelemetry.instrumentation.celery import CeleryInstrumentor - from opentelemetry.metrics import get_meter_provider - from opentelemetry.trace import get_tracer_provider - - tracer_provider = get_tracer_provider() - metric_provider = get_meter_provider() - if dify_config.DEBUG: - logger.info("Initializing OpenTelemetry for Celery worker") - CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument() diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 588fbae285..5e75bc36b0 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -3,7 +3,7 @@ import logging import ssl from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union import redis from redis import RedisError @@ -245,7 +245,12 @@ def init_app(app: DifyApp): app.extensions["redis"] = redis_client -def redis_fallback(default_return: Any | None = None): +P = ParamSpec("P") +R = TypeVar("R") +T = TypeVar("T") + + +def redis_fallback(default_return: T | None = None): # type: ignore """ decorator to handle Redis operation exceptions and return a default value when Redis is unavailable. @@ -253,9 +258,9 @@ def redis_fallback(default_return: Any | None = None): default_return: The value to return when a Redis operation fails. Defaults to None. """ - def decorator(func: Callable): + def decorator(func: Callable[P, R]): @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs): try: return func(*args, **kwargs) except RedisError as e: diff --git a/api/extensions/ext_request_logging.py b/api/extensions/ext_request_logging.py index f7263e18c4..8ea7b97f47 100644 --- a/api/extensions/ext_request_logging.py +++ b/api/extensions/ext_request_logging.py @@ -1,12 +1,14 @@ import json import logging +import time import flask import werkzeug.http -from flask import Flask +from flask import Flask, g from flask.signals import request_finished, request_started from configs import dify_config +from core.helper.trace_id_helper import get_trace_id_from_otel_context logger = logging.getLogger(__name__) @@ -20,6 +22,9 @@ def _is_content_type_json(content_type: str) -> bool: def _log_request_started(_sender, **_extra): """Log the start of a request.""" + # Record start time for access logging + g.__request_started_ts = time.perf_counter() + if not logger.isEnabledFor(logging.DEBUG): return @@ -42,8 +47,39 @@ def _log_request_started(_sender, **_extra): def _log_request_finished(_sender, response, **_extra): - """Log the end of a request.""" - if not logger.isEnabledFor(logging.DEBUG) or response is None: + """Log the end of a request. + + Safe to call with or without an active Flask request context. + """ + if response is None: + return + + # Always emit a compact access line at INFO with trace_id so it can be grepped + has_ctx = flask.has_request_context() + start_ts = getattr(g, "__request_started_ts", None) if has_ctx else None + duration_ms = None + if start_ts is not None: + duration_ms = round((time.perf_counter() - start_ts) * 1000, 3) + + # Request attributes are available only when a request context exists + if has_ctx: + req_method = flask.request.method + req_path = flask.request.path + else: + req_method = "-" + req_path = "-" + + trace_id = get_trace_id_from_otel_context() or response.headers.get("X-Trace-Id") or "" + logger.info( + "%s %s %s %s %s", + req_method, + req_path, + getattr(response, "status_code", "-"), + duration_ms if duration_ms is not None else "-", + trace_id, + ) + + if not logger.isEnabledFor(logging.DEBUG): return if not _is_content_type_json(response.content_type): diff --git a/api/extensions/otel/__init__.py b/api/extensions/otel/__init__.py new file mode 100644 index 0000000000..a431698d3d --- /dev/null +++ b/api/extensions/otel/__init__.py @@ -0,0 +1,11 @@ +from extensions.otel.decorators.base import trace_span +from extensions.otel.decorators.handler import SpanHandler +from extensions.otel.decorators.handlers.generate_handler import AppGenerateHandler +from extensions.otel.decorators.handlers.workflow_app_runner_handler import WorkflowAppRunnerHandler + +__all__ = [ + "AppGenerateHandler", + "SpanHandler", + "WorkflowAppRunnerHandler", + "trace_span", +] diff --git a/sdks/python-client/tests/__init__.py b/api/extensions/otel/decorators/__init__.py similarity index 100% rename from sdks/python-client/tests/__init__.py rename to api/extensions/otel/decorators/__init__.py diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py new file mode 100644 index 0000000000..9604a3b6d5 --- /dev/null +++ b/api/extensions/otel/decorators/base.py @@ -0,0 +1,61 @@ +import functools +import os +from collections.abc import Callable +from typing import Any, TypeVar, cast + +from opentelemetry.trace import get_tracer + +from configs import dify_config +from extensions.otel.decorators.handler import SpanHandler + +T = TypeVar("T", bound=Callable[..., Any]) + +_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()} + + +def _is_instrument_flag_enabled() -> bool: + """ + Check if external instrumentation is enabled via environment variable. + + Third-party non-invasive instrumentation agents set this flag to coordinate + with Dify's manual OpenTelemetry instrumentation. + """ + return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true" + + +def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler: + """Get or create a singleton instance of the handler class.""" + if handler_class not in _HANDLER_INSTANCES: + _HANDLER_INSTANCES[handler_class] = handler_class() + return _HANDLER_INSTANCES[handler_class] + + +def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T], T]: + """ + Decorator that traces a function with an OpenTelemetry span. + + The decorator uses the provided handler class to create a singleton handler instance + and delegates the wrapper implementation to that handler. + + :param handler_class: Optional handler class to use for this span. If None, uses the default SpanHandler. + """ + + def decorator(func: T) -> T: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if not (dify_config.ENABLE_OTEL or _is_instrument_flag_enabled()): + return func(*args, **kwargs) + + handler = _get_handler_instance(handler_class or SpanHandler) + tracer = get_tracer(__name__) + + return handler.wrapper( + tracer=tracer, + wrapped=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, wrapper) + + return decorator diff --git a/api/extensions/otel/decorators/handler.py b/api/extensions/otel/decorators/handler.py new file mode 100644 index 0000000000..1a7def5b0b --- /dev/null +++ b/api/extensions/otel/decorators/handler.py @@ -0,0 +1,95 @@ +import inspect +from collections.abc import Callable, Mapping +from typing import Any + +from opentelemetry.trace import SpanKind, Status, StatusCode + + +class SpanHandler: + """ + Base class for all span handlers. + + Each instrumentation point provides a handler implementation that fully controls + how spans are created, annotated, and finalized through the wrapper method. + + This class provides a default implementation that creates a basic span and handles + exceptions. Handlers can override the wrapper method to customize behavior. + """ + + _signature_cache: dict[Callable[..., Any], inspect.Signature] = {} + + def _build_span_name(self, wrapped: Callable[..., Any]) -> str: + """ + Build the span name from the wrapped function. + + Handlers can override this method to customize span name generation. + + :param wrapped: The original function being traced + :return: The span name + """ + return f"{wrapped.__module__}.{wrapped.__qualname__}" + + def _extract_arguments( + self, + wrapped: Callable[..., Any], + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + ) -> dict[str, Any] | None: + """ + Extract function arguments using inspect.signature. + + Returns a dictionary of bound arguments, or None if extraction fails. + Handlers can use this to safely extract parameters from args/kwargs. + + The function signature is cached to improve performance on repeated calls. + + :param wrapped: The function being traced + :param args: Positional arguments + :param kwargs: Keyword arguments + :return: Dictionary of bound arguments, or None if extraction fails + """ + try: + if wrapped not in self._signature_cache: + self._signature_cache[wrapped] = inspect.signature(wrapped) + + sig = self._signature_cache[wrapped] + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + return bound.arguments + except Exception: + return None + + def wrapper( + self, + tracer: Any, + wrapped: Callable[..., Any], + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + ) -> Any: + """ + Fully control the wrapper behavior. + + Default implementation creates a basic span and handles exceptions. + Handlers can override this method to provide complete control over: + - Span creation and configuration + - Attribute extraction + - Function invocation + - Exception handling + - Status setting + + :param tracer: OpenTelemetry tracer instance + :param wrapped: The original function being traced + :param args: Positional arguments (including self/cls if applicable) + :param kwargs: Keyword arguments + :return: Result of calling wrapped function + """ + span_name = self._build_span_name(wrapped) + with tracer.start_as_current_span(span_name, kind=SpanKind.INTERNAL) as span: + try: + result = wrapped(*args, **kwargs) + span.set_status(Status(StatusCode.OK)) + return result + except Exception as exc: + span.record_exception(exc) + span.set_status(Status(StatusCode.ERROR, str(exc))) + raise diff --git a/api/extensions/otel/decorators/handlers/__init__.py b/api/extensions/otel/decorators/handlers/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/extensions/otel/decorators/handlers/__init__.py @@ -0,0 +1 @@ + diff --git a/api/extensions/otel/decorators/handlers/generate_handler.py b/api/extensions/otel/decorators/handlers/generate_handler.py new file mode 100644 index 0000000000..63748a9824 --- /dev/null +++ b/api/extensions/otel/decorators/handlers/generate_handler.py @@ -0,0 +1,64 @@ +import logging +from collections.abc import Callable, Mapping +from typing import Any + +from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.util.types import AttributeValue + +from extensions.otel.decorators.handler import SpanHandler +from extensions.otel.semconv import DifySpanAttributes, GenAIAttributes +from models.model import Account + +logger = logging.getLogger(__name__) + + +class AppGenerateHandler(SpanHandler): + """Span handler for ``AppGenerateService.generate``.""" + + def wrapper( + self, + tracer: Any, + wrapped: Callable[..., Any], + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + ) -> Any: + try: + arguments = self._extract_arguments(wrapped, args, kwargs) + if not arguments: + return wrapped(*args, **kwargs) + + app_model = arguments.get("app_model") + user = arguments.get("user") + args_dict = arguments.get("args", {}) + streaming = arguments.get("streaming", True) + + if not app_model or not user or not isinstance(args_dict, dict): + return wrapped(*args, **kwargs) + app_id = getattr(app_model, "id", None) or "unknown" + tenant_id = getattr(app_model, "tenant_id", None) or "unknown" + user_id = getattr(user, "id", None) or "unknown" + workflow_id = args_dict.get("workflow_id") or "unknown" + + attributes: dict[str, AttributeValue] = { + DifySpanAttributes.APP_ID: app_id, + DifySpanAttributes.TENANT_ID: tenant_id, + GenAIAttributes.USER_ID: user_id, + DifySpanAttributes.USER_TYPE: "Account" if isinstance(user, Account) else "EndUser", + DifySpanAttributes.STREAMING: streaming, + DifySpanAttributes.WORKFLOW_ID: workflow_id, + } + + span_name = self._build_span_name(wrapped) + except Exception as exc: + logger.warning("Failed to prepare span attributes for AppGenerateService.generate: %s", exc, exc_info=True) + return wrapped(*args, **kwargs) + + with tracer.start_as_current_span(span_name, kind=SpanKind.INTERNAL, attributes=attributes) as span: + try: + result = wrapped(*args, **kwargs) + span.set_status(Status(StatusCode.OK)) + return result + except Exception as exc: + span.record_exception(exc) + span.set_status(Status(StatusCode.ERROR, str(exc))) + raise diff --git a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py new file mode 100644 index 0000000000..8abd60197c --- /dev/null +++ b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py @@ -0,0 +1,65 @@ +import logging +from collections.abc import Callable, Mapping +from typing import Any + +from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.util.types import AttributeValue + +from extensions.otel.decorators.handler import SpanHandler +from extensions.otel.semconv import DifySpanAttributes, GenAIAttributes + +logger = logging.getLogger(__name__) + + +class WorkflowAppRunnerHandler(SpanHandler): + """Span handler for ``WorkflowAppRunner.run``.""" + + def wrapper( + self, + tracer: Any, + wrapped: Callable[..., Any], + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + ) -> Any: + try: + arguments = self._extract_arguments(wrapped, args, kwargs) + if not arguments: + return wrapped(*args, **kwargs) + + runner = arguments.get("self") + if runner is None or not hasattr(runner, "application_generate_entity"): + return wrapped(*args, **kwargs) + + entity = runner.application_generate_entity + app_config = getattr(entity, "app_config", None) + if app_config is None: + return wrapped(*args, **kwargs) + + user_id: AttributeValue = getattr(entity, "user_id", None) or "unknown" + app_id: AttributeValue = getattr(app_config, "app_id", None) or "unknown" + tenant_id: AttributeValue = getattr(app_config, "tenant_id", None) or "unknown" + workflow_id: AttributeValue = getattr(app_config, "workflow_id", None) or "unknown" + streaming = getattr(entity, "stream", True) + + attributes: dict[str, AttributeValue] = { + DifySpanAttributes.APP_ID: app_id, + DifySpanAttributes.TENANT_ID: tenant_id, + GenAIAttributes.USER_ID: user_id, + DifySpanAttributes.STREAMING: streaming, + DifySpanAttributes.WORKFLOW_ID: workflow_id, + } + + span_name = self._build_span_name(wrapped) + except Exception as exc: + logger.warning("Failed to prepare span attributes for WorkflowAppRunner.run: %s", exc, exc_info=True) + return wrapped(*args, **kwargs) + + with tracer.start_as_current_span(span_name, kind=SpanKind.INTERNAL, attributes=attributes) as span: + try: + result = wrapped(*args, **kwargs) + span.set_status(Status(StatusCode.OK)) + return result + except Exception as exc: + span.record_exception(exc) + span.set_status(Status(StatusCode.ERROR, str(exc))) + raise diff --git a/api/extensions/otel/instrumentation.py b/api/extensions/otel/instrumentation.py new file mode 100644 index 0000000000..3597110cba --- /dev/null +++ b/api/extensions/otel/instrumentation.py @@ -0,0 +1,108 @@ +import contextlib +import logging + +import flask +from opentelemetry.instrumentation.celery import CeleryInstrumentor +from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor +from opentelemetry.instrumentation.redis import RedisInstrumentor +from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor +from opentelemetry.metrics import get_meter, get_meter_provider +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.trace import Span, get_tracer_provider +from opentelemetry.trace.status import StatusCode + +from configs import dify_config +from dify_app import DifyApp +from extensions.otel.runtime import is_celery_worker + +logger = logging.getLogger(__name__) + + +class ExceptionLoggingHandler(logging.Handler): + def emit(self, record: logging.LogRecord): + with contextlib.suppress(Exception): + if record.exc_info: + tracer = get_tracer_provider().get_tracer("dify.exception.logging") + with tracer.start_as_current_span( + "log.exception", + attributes={ + "log.level": record.levelname, + "log.message": record.getMessage(), + "log.logger": record.name, + "log.file.path": record.pathname, + "log.file.line": record.lineno, + }, + ) as span: + span.set_status(StatusCode.ERROR) + if record.exc_info[1]: + span.record_exception(record.exc_info[1]) + span.set_attribute("exception.message", str(record.exc_info[1])) + if record.exc_info[0]: + span.set_attribute("exception.type", record.exc_info[0].__name__) + + +def instrument_exception_logging() -> None: + exception_handler = ExceptionLoggingHandler() + logging.getLogger().addHandler(exception_handler) + + +def init_flask_instrumentor(app: DifyApp) -> None: + meter = get_meter("http_metrics", version=dify_config.project.version) + _http_response_counter = meter.create_counter( + "http.server.response.count", + description="Total number of HTTP responses by status code, method and target", + unit="{response}", + ) + + def response_hook(span: Span, status: str, response_headers: list) -> None: + if span and span.is_recording(): + try: + if status.startswith("2"): + span.set_status(StatusCode.OK) + else: + span.set_status(StatusCode.ERROR, status) + + status = status.split(" ")[0] + status_code = int(status) + status_class = f"{status_code // 100}xx" + attributes: dict[str, str | int] = {"status_code": status_code, "status_class": status_class} + request = flask.request + if request and request.url_rule: + attributes[SpanAttributes.HTTP_TARGET] = str(request.url_rule.rule) + if request and request.method: + attributes[SpanAttributes.HTTP_METHOD] = str(request.method) + _http_response_counter.add(1, attributes) + except Exception: + logger.exception("Error setting status and attributes") + + from opentelemetry.instrumentation.flask import FlaskInstrumentor + + instrumentor = FlaskInstrumentor() + if dify_config.DEBUG: + logger.info("Initializing Flask instrumentor") + instrumentor.instrument_app(app, response_hook=response_hook) + + +def init_sqlalchemy_instrumentor(app: DifyApp) -> None: + with app.app_context(): + engines = list(app.extensions["sqlalchemy"].engines.values()) + SQLAlchemyInstrumentor().instrument(enable_commenter=True, engines=engines) + + +def init_redis_instrumentor() -> None: + RedisInstrumentor().instrument() + + +def init_httpx_instrumentor() -> None: + HTTPXClientInstrumentor().instrument() + + +def init_instruments(app: DifyApp) -> None: + if not is_celery_worker(): + init_flask_instrumentor(app) + CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument() + + instrument_exception_logging() + init_sqlalchemy_instrumentor(app) + init_redis_instrumentor() + init_httpx_instrumentor() diff --git a/api/extensions/otel/runtime.py b/api/extensions/otel/runtime.py new file mode 100644 index 0000000000..16f5ccf488 --- /dev/null +++ b/api/extensions/otel/runtime.py @@ -0,0 +1,73 @@ +import logging +import sys +from typing import Union + +from celery.signals import worker_init +from flask_login import user_loaded_from_request, user_logged_in +from opentelemetry import trace +from opentelemetry.propagate import set_global_textmap +from opentelemetry.propagators.b3 import B3Format +from opentelemetry.propagators.composite import CompositePropagator +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + +from configs import dify_config +from extensions.otel.semconv import DifySpanAttributes, GenAIAttributes +from libs.helper import extract_tenant_id +from models import Account, EndUser + +logger = logging.getLogger(__name__) + + +def setup_context_propagation() -> None: + set_global_textmap( + CompositePropagator( + [ + TraceContextTextMapPropagator(), + B3Format(), + ] + ) + ) + + +def shutdown_tracer() -> None: + provider = trace.get_tracer_provider() + if hasattr(provider, "force_flush"): + provider.force_flush() + + +def is_celery_worker(): + return "celery" in sys.argv[0].lower() + + +@user_logged_in.connect +@user_loaded_from_request.connect +def on_user_loaded(_sender, user: Union["Account", "EndUser"]): + if dify_config.ENABLE_OTEL: + from opentelemetry.trace import get_current_span + + if user: + try: + current_span = get_current_span() + tenant_id = extract_tenant_id(user) + if not tenant_id: + return + if current_span: + current_span.set_attribute(DifySpanAttributes.TENANT_ID, tenant_id) + current_span.set_attribute(GenAIAttributes.USER_ID, user.id) + except Exception: + logger.exception("Error setting tenant and user attributes") + pass + + +@worker_init.connect(weak=False) +def init_celery_worker(*args, **kwargs): + if dify_config.ENABLE_OTEL: + from opentelemetry.instrumentation.celery import CeleryInstrumentor + from opentelemetry.metrics import get_meter_provider + from opentelemetry.trace import get_tracer_provider + + tracer_provider = get_tracer_provider() + metric_provider = get_meter_provider() + if dify_config.DEBUG: + logger.info("Initializing OpenTelemetry for Celery worker") + CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument() diff --git a/api/extensions/otel/semconv/__init__.py b/api/extensions/otel/semconv/__init__.py new file mode 100644 index 0000000000..dc79dee222 --- /dev/null +++ b/api/extensions/otel/semconv/__init__.py @@ -0,0 +1,6 @@ +"""Semantic convention shortcuts for Dify-specific spans.""" + +from .dify import DifySpanAttributes +from .gen_ai import GenAIAttributes + +__all__ = ["DifySpanAttributes", "GenAIAttributes"] diff --git a/api/extensions/otel/semconv/dify.py b/api/extensions/otel/semconv/dify.py new file mode 100644 index 0000000000..a20b9b358d --- /dev/null +++ b/api/extensions/otel/semconv/dify.py @@ -0,0 +1,23 @@ +"""Dify-specific semantic convention definitions.""" + + +class DifySpanAttributes: + """Attribute names for Dify-specific spans.""" + + APP_ID = "dify.app_id" + """Application identifier.""" + + TENANT_ID = "dify.tenant_id" + """Tenant identifier.""" + + USER_TYPE = "dify.user_type" + """User type, e.g. Account, EndUser.""" + + STREAMING = "dify.streaming" + """Whether streaming response is enabled.""" + + WORKFLOW_ID = "dify.workflow_id" + """Workflow identifier.""" + + INVOKE_FROM = "dify.invoke_from" + """Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER.""" diff --git a/api/extensions/otel/semconv/gen_ai.py b/api/extensions/otel/semconv/gen_ai.py new file mode 100644 index 0000000000..83c52ed34f --- /dev/null +++ b/api/extensions/otel/semconv/gen_ai.py @@ -0,0 +1,64 @@ +""" +GenAI semantic conventions. +""" + + +class GenAIAttributes: + """Common GenAI attribute keys.""" + + USER_ID = "gen_ai.user.id" + """Identifier of the end user in the application layer.""" + + FRAMEWORK = "gen_ai.framework" + """Framework type. Fixed to 'dify' in this project.""" + + SPAN_KIND = "gen_ai.span.kind" + """Operation type. Extended specification, not in OTel standard.""" + + +class ChainAttributes: + """Chain operation attribute keys.""" + + OPERATION_NAME = "gen_ai.operation.name" + """Secondary operation type, e.g. WORKFLOW, TASK.""" + + INPUT_VALUE = "input.value" + """Input content.""" + + OUTPUT_VALUE = "output.value" + """Output content.""" + + TIME_TO_FIRST_TOKEN = "gen_ai.user.time_to_first_token" + """Time to first token in nanoseconds from receiving the request to first token return.""" + + +class RetrieverAttributes: + """Retriever operation attribute keys.""" + + QUERY = "retrieval.query" + """Retrieval query string.""" + + DOCUMENT = "retrieval.document" + """Retrieved document list as JSON array.""" + + +class ToolAttributes: + """Tool operation attribute keys.""" + + TOOL_CALL_ID = "gen_ai.tool.call.id" + """Tool call identifier.""" + + TOOL_DESCRIPTION = "gen_ai.tool.description" + """Tool description.""" + + TOOL_NAME = "gen_ai.tool.name" + """Tool name.""" + + TOOL_TYPE = "gen_ai.tool.type" + """Tool type. Examples: function, extension, datastore.""" + + TOOL_CALL_ARGUMENTS = "gen_ai.tool.call.arguments" + """Tool invocation arguments.""" + + TOOL_CALL_RESULT = "gen_ai.tool.call.result" + """Tool invocation result.""" diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index dc5aa8e39c..51a97b20f8 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -199,9 +199,9 @@ class FileLifecycleManager: # Temporarily create basic metadata information except ValueError: continue - except: + except Exception: # If cannot scan version files, only return current version - pass + logger.exception("Failed to scan version files for %s", filename) return sorted(versions, key=lambda x: x.version or 0, reverse=True) diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index f7146adba6..a084844d72 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -87,7 +87,7 @@ class OpenDALStorage(BaseStorage): if not self.exists(path): raise FileNotFoundError("Path not found") - all_files = self.op.list(path=path) + all_files = self.op.scan(path=path) if files and directories: logger.debug("files and directories on %s scanned", path) return [f.path for f in all_files] diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 89c4d8fba9..1e5ec7d200 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -97,11 +97,27 @@ dataset_detail_fields = { "total_documents": fields.Integer, "total_available_documents": fields.Integer, "enable_api": fields.Boolean, + "is_multimodal": fields.Boolean, +} + +file_info_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + +content_fields = { + "content_type": fields.String, + "content": fields.String, + "file_info": fields.Nested(file_info_fields, allow_null=True), } dataset_query_detail_fields = { "id": fields.String, - "content": fields.String, + "queries": fields.Nested(content_fields), "source": fields.String, "source_app_id": fields.String, "created_by_role": fields.String, diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index c12ebc09c8..a707500445 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -9,6 +9,8 @@ upload_config_fields = { "video_file_size_limit": fields.Integer, "audio_file_size_limit": fields.Integer, "workflow_file_upload_limit": fields.Integer, + "image_file_batch_limit": fields.Integer, + "single_chunk_attachment_limit": fields.Integer, } diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 75bdff1803..e70f9fa722 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -43,9 +43,19 @@ child_chunk_fields = { "score": fields.Float, } +files_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + hit_testing_record_fields = { "segment": fields.Nested(segment_fields), "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "score": fields.Float, "tsne_position": fields.Raw, + "files": fields.List(fields.Nested(files_fields)), } diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 2ff917d6bc..56d6b68378 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -13,6 +13,15 @@ child_chunk_fields = { "updated_at": TimestampField, } +attachment_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + segment_fields = { "id": fields.String, "position": fields.Integer, @@ -39,4 +48,5 @@ segment_fields = { "error": fields.String, "stopped_at": TimestampField, "child_chunks": fields.List(fields.Nested(child_chunk_fields)), + "attachments": fields.List(fields.Nested(attachment_fields)), } diff --git a/api/libs/helper.py b/api/libs/helper.py index 1013c3b878..abc81d1fde 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -10,12 +10,13 @@ import uuid from collections.abc import Generator, Mapping from datetime import datetime from hashlib import sha256 -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restx import fields from pydantic import BaseModel +from pydantic.functional_validators import AfterValidator from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator @@ -103,7 +104,10 @@ def email(email): raise ValueError(error) -def uuid_value(value): +EmailStr = Annotated[str, AfterValidator(email)] + + +def uuid_value(value: Any) -> str: if value == "": return str(value) @@ -211,7 +215,11 @@ def generate_text_hash(text: str) -> str: def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: if isinstance(response, dict): - return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json") + return Response( + response=json.dumps(jsonable_encoder(response)), + status=200, + content_type="application/json; charset=utf-8", + ) else: def generate() -> Generator: diff --git a/api/libs/token.py b/api/libs/token.py index 098ff958da..a34db70764 100644 --- a/api/libs/token.py +++ b/api/libs/token.py @@ -189,6 +189,11 @@ def build_force_logout_cookie_headers() -> list[str]: def check_csrf_token(request: Request, user_id: str): # some apis are sent by beacon, so we need to bypass csrf token check # since these APIs are post, they are already protected by SameSite: Lax, so csrf is not required. + if dify_config.ADMIN_API_KEY_ENABLE: + auth_token = extract_access_token(request) + if auth_token and auth_token == dify_config.ADMIN_API_KEY: + return + def _unauthorized(): raise Unauthorized("CSRF token is missing or invalid.") diff --git a/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py b/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py new file mode 100644 index 0000000000..187bf7136d --- /dev/null +++ b/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py @@ -0,0 +1,57 @@ +"""support-multi-modal + +Revision ID: d57accd375ae +Revises: 03f8dcbc611e +Create Date: 2025-11-12 15:37:12.363670 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd57accd375ae' +down_revision = '7bb281b7a422' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('segment_attachment_bindings', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('segment_id', models.types.StringUUID(), nullable=False), + sa.Column('attachment_id', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.PrimaryKeyConstraint('id', name='segment_attachment_binding_pkey') + ) + with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op: + batch_op.create_index( + 'segment_attachment_binding_tenant_dataset_document_segment_idx', + ['tenant_id', 'dataset_id', 'document_id', 'segment_id'], + unique=False + ) + batch_op.create_index('segment_attachment_binding_attachment_idx', ['attachment_id'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('is_multimodal', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('is_multimodal') + + + with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op: + batch_op.drop_index('segment_attachment_binding_attachment_idx') + batch_op.drop_index('segment_attachment_binding_tenant_dataset_document_segment_idx') + + op.drop_table('segment_attachment_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_11_15_2102-09cfdda155d1_mysql_adaptation.py b/api/migrations/versions/2025_11_15_2102-09cfdda155d1_mysql_adaptation.py index a3f6c3cb19..877fa2f309 100644 --- a/api/migrations/versions/2025_11_15_2102-09cfdda155d1_mysql_adaptation.py +++ b/api/migrations/versions/2025_11_15_2102-09cfdda155d1_mysql_adaptation.py @@ -1,4 +1,4 @@ -"""empty message +"""mysql adaptation Revision ID: 09cfdda155d1 Revises: 669ffd70119c @@ -97,11 +97,31 @@ def downgrade(): batch_op.alter_column('include_plugins', existing_type=sa.JSON(), type_=postgresql.ARRAY(sa.VARCHAR(length=255)), - existing_nullable=False) + existing_nullable=False, + postgresql_using=""" + COALESCE( + regexp_replace( + replace(replace(include_plugins::text, '[', '{'), ']', '}'), + '"', + '', + 'g' + )::varchar(255)[], + ARRAY[]::varchar(255)[] + )""") batch_op.alter_column('exclude_plugins', existing_type=sa.JSON(), type_=postgresql.ARRAY(sa.VARCHAR(length=255)), - existing_nullable=False) + existing_nullable=False, + postgresql_using=""" + COALESCE( + regexp_replace( + replace(replace(exclude_plugins::text, '[', '{'), ']', '}'), + '"', + '', + 'g' + )::varchar(255)[], + ARRAY[]::varchar(255)[] + )""") with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op: batch_op.alter_column('external_knowledge_id', diff --git a/api/models/dataset.py b/api/models/dataset.py index e072711b82..ba2eaf6749 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -19,7 +19,9 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource +from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.signature import sign_upload_file from extensions.ext_storage import storage from libs.uuid_utils import uuidv7 from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -76,6 +78,7 @@ class Dataset(Base): pipeline_id = mapped_column(StringUUID, nullable=True) chunk_structure = mapped_column(sa.String(255), nullable=True) enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + is_multimodal = mapped_column(sa.Boolean, default=False, nullable=False, server_default=db.text("false")) @property def total_documents(self): @@ -728,9 +731,7 @@ class DocumentSegment(Base): created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() - ) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) error = mapped_column(LongText, nullable=True) @@ -866,6 +867,47 @@ class DocumentSegment(Base): return text + @property + def attachments(self) -> list[dict[str, Any]]: + # Use JOIN to fetch attachments in a single query instead of two separate queries + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == self.tenant_id, + SegmentAttachmentBinding.dataset_id == self.dataset_id, + SegmentAttachmentBinding.document_id == self.document_id, + SegmentAttachmentBinding.segment_id == self.id, + ) + ).all() + if not attachments_with_bindings: + return [] + attachment_list = [] + for _, attachment in attachments_with_bindings: + upload_file_id = attachment.id + nonce = os.urandom(16).hex() + timestamp = str(int(time.time())) + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + reference_url = dify_config.CONSOLE_API_URL or "" + base_url = f"{reference_url}/files/{upload_file_id}/image-preview" + source_url = f"{base_url}?{params}" + attachment_list.append( + { + "id": attachment.id, + "name": attachment.name, + "size": attachment.size, + "extension": attachment.extension, + "mime_type": attachment.mime_type, + "source_url": source_url, + } + ) + return attachment_list + class ChildChunk(Base): __tablename__ = "child_chunks" @@ -963,6 +1005,38 @@ class DatasetQuery(TypeBase): DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False ) + @property + def queries(self) -> list[dict[str, Any]]: + try: + queries = json.loads(self.content) + if isinstance(queries, list): + for query in queries: + if query["content_type"] == QueryType.IMAGE_QUERY: + file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first() + if file_info: + query["file_info"] = { + "id": file_info.id, + "name": file_info.name, + "size": file_info.size, + "extension": file_info.extension, + "mime_type": file_info.mime_type, + "source_url": sign_upload_file(file_info.id, file_info.extension), + } + else: + query["file_info"] = None + + return queries + else: + return [queries] + except JSONDecodeError: + return [ + { + "content_type": QueryType.TEXT_QUERY, + "content": self.content, + "file_info": None, + } + ] + class DatasetKeywordTable(TypeBase): __tablename__ = "dataset_keyword_tables" @@ -1470,3 +1544,25 @@ class PipelineRecommendedPlugin(TypeBase): onupdate=func.current_timestamp(), init=False, ) + + +class SegmentAttachmentBinding(Base): + __tablename__ = "segment_attachment_bindings" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"), + sa.Index( + "segment_attachment_binding_tenant_dataset_document_segment_idx", + "tenant_id", + "dataset_id", + "document_id", + "segment_id", + ), + sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"), + ) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/model.py b/api/models/model.py index 1731ff5699..88cb945b3f 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -111,7 +111,11 @@ class App(Base): else: app_model_config = self.app_model_config if app_model_config: - return app_model_config.pre_prompt + pre_prompt = app_model_config.pre_prompt or "" + # Truncate to 200 characters with ellipsis if using prompt as description + if len(pre_prompt) > 200: + return pre_prompt[:200] + "..." + return pre_prompt else: return "" @@ -259,7 +263,7 @@ class App(Base): provider_id = tool.get("provider_id", "") if provider_type == ToolProviderType.API: - if uuid.UUID(provider_id) not in existing_api_providers: + if provider_id not in existing_api_providers: deleted_tools.append( { "type": ToolProviderType.API, @@ -835,7 +839,29 @@ class Conversation(Base): @property def status_count(self): - messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all() + from models.workflow import WorkflowRun + + # Get all messages with workflow_run_id for this conversation + messages = db.session.scalars( + select(Message).where(Message.conversation_id == self.id, Message.workflow_run_id.isnot(None)) + ).all() + + if not messages: + return None + + # Batch load all workflow runs in a single query, filtered by this conversation's app_id + workflow_run_ids = [msg.workflow_run_id for msg in messages if msg.workflow_run_id] + workflow_runs = {} + + if workflow_run_ids: + workflow_runs_query = db.session.scalars( + select(WorkflowRun).where( + WorkflowRun.id.in_(workflow_run_ids), + WorkflowRun.app_id == self.app_id, # Filter by this conversation's app_id + ) + ).all() + workflow_runs = {run.id: run for run in workflow_runs_query} + status_counts = { WorkflowExecutionStatus.RUNNING: 0, WorkflowExecutionStatus.SUCCEEDED: 0, @@ -845,18 +871,24 @@ class Conversation(Base): } for message in messages: - if message.workflow_run: - status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1 + # Guard against None to satisfy type checker and avoid invalid dict lookups + if message.workflow_run_id is None: + continue + workflow_run = workflow_runs.get(message.workflow_run_id) + if not workflow_run: + continue - return ( - { - "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], - "failed": status_counts[WorkflowExecutionStatus.FAILED], - "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED], - } - if messages - else None - ) + try: + status_counts[WorkflowExecutionStatus(workflow_run.status)] += 1 + except (ValueError, KeyError): + # Handle invalid status values gracefully + pass + + return { + "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], + "failed": status_counts[WorkflowExecutionStatus.FAILED], + "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED], + } @property def first_message(self): @@ -1255,13 +1287,9 @@ class Message(Base): "id": self.id, "app_id": self.app_id, "conversation_id": self.conversation_id, - "model_provider": self.model_provider, "model_id": self.model_id, "inputs": self.inputs, "query": self.query, - "message_tokens": self.message_tokens, - "answer_tokens": self.answer_tokens, - "provider_response_latency": self.provider_response_latency, "total_price": self.total_price, "message": self.message, "answer": self.answer, @@ -1283,12 +1311,8 @@ class Message(Base): id=data["id"], app_id=data["app_id"], conversation_id=data["conversation_id"], - model_provider=data.get("model_provider"), model_id=data["model_id"], inputs=data["inputs"], - message_tokens=data.get("message_tokens", 0), - answer_tokens=data.get("answer_tokens", 0), - provider_response_latency=data.get("provider_response_latency", 0.0), total_price=data["total_price"], query=data["query"], message=data["message"], diff --git a/api/models/workflow.py b/api/models/workflow.py index 42ee8a1f2b..853d5afefc 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -907,19 +907,29 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @property def extras(self) -> dict[str, Any]: from core.tools.tool_manager import ToolManager + from core.trigger.trigger_manager import TriggerManager extras: dict[str, Any] = {} - if self.execution_metadata_dict: - if self.node_type == NodeType.TOOL and "tool_info" in self.execution_metadata_dict: - tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"] + execution_metadata = self.execution_metadata_dict + if execution_metadata: + if self.node_type == NodeType.TOOL and "tool_info" in execution_metadata: + tool_info: dict[str, Any] = execution_metadata["tool_info"] extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, provider_type=tool_info["provider_type"], provider_id=tool_info["provider_id"], ) - elif self.node_type == NodeType.DATASOURCE and "datasource_info" in self.execution_metadata_dict: - datasource_info = self.execution_metadata_dict["datasource_info"] + elif self.node_type == NodeType.DATASOURCE and "datasource_info" in execution_metadata: + datasource_info = execution_metadata["datasource_info"] extras["icon"] = datasource_info.get("icon") + elif self.node_type == NodeType.TRIGGER_PLUGIN and "trigger_info" in execution_metadata: + trigger_info = execution_metadata["trigger_info"] or {} + provider_id = trigger_info.get("provider_id") + if provider_id: + extras["icon"] = TriggerManager.get_trigger_plugin_icon( + tenant_id=self.tenant_id, + provider_id=provider_id, + ) return extras def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: diff --git a/api/pyproject.toml b/api/pyproject.toml index d28ba91413..092b5ab9f9 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.10.1" +version = "1.11.1" requires-python = ">=3.11,<3.13" dependencies = [ @@ -11,7 +11,7 @@ dependencies = [ "bs4~=0.0.1", "cachetools~=5.3.0", "celery~=5.5.2", - "chardet~=5.1.0", + "charset-normalizer>=3.4.4", "flask~=3.1.2", "flask-compress>=1.17,<1.18", "flask-cors~=6.0.0", @@ -91,6 +91,7 @@ dependencies = [ "weaviate-client==4.17.0", "apscheduler>=3.11.0", "weave>=0.52.16", + "jsonschema>=4.25.1", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -132,7 +133,7 @@ dev = [ "types-jsonschema~=4.23.0", "types-flask-cors~=5.0.0", "types-flask-migrate~=4.1.0", - "types-gevent~=24.11.0", + "types-gevent~=25.9.0", "types-greenlet~=3.1.0", "types-html5lib~=1.1.11", "types-markdown~=3.7.0", @@ -150,7 +151,7 @@ dev = [ "types-pywin32~=310.0.0", "types-pyyaml~=6.0.12", "types-regex~=2024.11.6", - "types-shapely~=2.0.0", + "types-shapely~=2.1.0", "types-simplejson>=3.20.0", "types-six>=1.17.0", "types-tensorflow>=2.18.0", diff --git a/api/pyrefly.toml b/api/pyrefly.toml new file mode 100644 index 0000000000..80ffba019d --- /dev/null +++ b/api/pyrefly.toml @@ -0,0 +1,10 @@ +project-includes = ["."] +project-excludes = [ + "tests/", + ".venv", + "migrations/", + "core/rag", +] +python-platform = "linux" +python-version = "3.11.0" +infer-with-first-use = false diff --git a/api/services/account_service.py b/api/services/account_service.py index ac6d1bde77..5a549dc318 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1259,7 +1259,7 @@ class RegisterService: return f"member_invite:token:{token}" @classmethod - def setup(cls, email: str, name: str, password: str, ip_address: str, language: str): + def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None): """ Setup dify @@ -1267,6 +1267,7 @@ class RegisterService: :param name: username :param password: password :param ip_address: ip address + :param language: language """ try: account = AccountService.create_account( @@ -1414,7 +1415,7 @@ class RegisterService: return data is not None @classmethod - def revoke_token(cls, workspace_id: str, email: str, token: str): + def revoke_token(cls, workspace_id: str | None, email: str | None, token: str): if workspace_id and email: email_hash = sha256(email.encode()).hexdigest() cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" @@ -1423,7 +1424,9 @@ class RegisterService: redis_client.delete(cls._get_invitation_token_key(token)) @classmethod - def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None: + def get_invitation_if_token_valid( + cls, workspace_id: str | None, email: str | None, token: str + ) -> dict[str, Any] | None: invitation_data = cls.get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index dc85929b98..4514c86f7c 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -11,6 +11,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting import RateLimit from enums.quota_type import QuotaType, unlimited +from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow from services.errors.app import InvokeRateLimitError, QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError @@ -19,6 +20,7 @@ from services.workflow_service import WorkflowService class AppGenerateService: @classmethod + @trace_span(AppGenerateHandler) def generate( cls, app_model: App, diff --git a/api/services/app_service.py b/api/services/app_service.py index 5f8c5089c9..ef89a4fd10 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -211,7 +211,7 @@ class AppService: # override tool parameters tool["tool_parameters"] = masked_parameter except Exception: - pass + logger.exception("Failed to mask agent tool parameters for tool %s", agent_tool_entity.tool_name) # override agent mode if model_config: diff --git a/api/services/attachment_service.py b/api/services/attachment_service.py new file mode 100644 index 0000000000..2bd5627d5e --- /dev/null +++ b/api/services/attachment_service.py @@ -0,0 +1,31 @@ +import base64 + +from sqlalchemy import Engine +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +from extensions.ext_storage import storage +from models.model import UploadFile + +PREVIEW_WORDS_LIMIT = 3000 + + +class AttachmentService: + _session_maker: sessionmaker + + def __init__(self, session_factory: sessionmaker | Engine | None = None): + if isinstance(session_factory, Engine): + self._session_maker = sessionmaker(bind=session_factory) + elif isinstance(session_factory, sessionmaker): + self._session_maker = session_factory + else: + raise AssertionError("must be a sessionmaker or an Engine.") + + def get_file_base64(self, file_id: str) -> str: + upload_file = ( + self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + ) + if not upload_file: + raise NotFound("File not found") + blob = storage.load_once(upload_file.key) + return base64.b64encode(blob).decode() diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 39d6c81621..5253199552 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -118,7 +118,7 @@ class ConversationService: app_model: App, conversation_id: str, user: Union[Account, EndUser] | None, - name: str, + name: str | None, auto_generate: bool, ): conversation = cls.get_conversation(app_model, conversation_id, user) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 2bec61963c..8097a6daa0 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,9 +7,10 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal +from typing import Any, Literal, cast import sqlalchemy as sa +from redis.exceptions import LockNotOwnedError from sqlalchemy import exists, func, select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -18,9 +19,10 @@ from configs import dify_config from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from enums.cloud_plan import CloudPlan from events.dataset_event import dataset_was_deleted @@ -45,12 +47,14 @@ from models.dataset import ( DocumentSegment, ExternalKnowledgeBindings, Pipeline, + SegmentAttachmentBinding, ) from models.model import UploadFile from models.provider_ids import ModelProviderID from models.source import DataSourceOauthBinding from models.workflow import Workflow -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy from services.entities.knowledge_entities.knowledge_entities import ( ChildChunkUpdateArgs, KnowledgeConfig, @@ -81,7 +85,6 @@ from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task from tasks.disable_segments_from_index_task import disable_segments_from_index_task from tasks.document_indexing_update_task import document_indexing_update_task -from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task from tasks.remove_document_from_index_task import remove_document_from_index_task @@ -362,6 +365,27 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) + @staticmethod + def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str): + try: + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=model, + ) + text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance) + model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials) + if not model_schema: + raise ValueError("Model schema not found") + if model_schema.features and ModelFeature.VISION in model_schema.features: + return True + else: + return False + except LLMBadRequestError: + raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.") + @staticmethod def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): try: @@ -401,13 +425,13 @@ class DatasetService: if not dataset: raise ValueError("Dataset not found") # check if dataset name is exists - - if DatasetService._has_dataset_same_name( - tenant_id=dataset.tenant_id, - dataset_id=dataset_id, - name=data.get("name", dataset.name), - ): - raise ValueError("Dataset name already exists") + if data.get("name") and data.get("name") != dataset.name: + if DatasetService._has_dataset_same_name( + tenant_id=dataset.tenant_id, + dataset_id=dataset_id, + name=data.get("name", dataset.name), + ): + raise ValueError("Dataset name already exists") # Verify user has permission to update this dataset DatasetService.check_dataset_permission(dataset, user) @@ -649,6 +673,8 @@ class DatasetService: Returns: str: Action to perform ('add', 'remove', 'update', or None) """ + if "indexing_technique" not in data: + return None if dataset.indexing_technique != data["indexing_technique"]: if data["indexing_technique"] == "economy": # Remove embedding model configuration for economy mode @@ -843,6 +869,12 @@ class DatasetService: model_type=ModelType.TEXT_EMBEDDING, model=knowledge_configuration.embedding_model or "", ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -879,6 +911,12 @@ class DatasetService: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_model.provider, embedding_model.model ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id dataset.indexing_technique = knowledge_configuration.indexing_technique except LLMBadRequestError: @@ -936,6 +974,12 @@ class DatasetService: ) ) dataset.collection_binding_id = dataset_collection_binding.id + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -1592,45 +1636,61 @@ class DocumentService: return [], "" db.session.add(dataset_process_rule) db.session.flush() - lock_name = f"add_document_lock_dataset_id_{dataset.id}" - with redis_client.lock(lock_name, timeout=600): - assert dataset_process_rule - position = DocumentService.get_documents_position(dataset.id) - document_ids = [] - duplicate_document_ids = [] - if knowledge_config.data_source.info_list.data_source_type == "upload_file": - if not knowledge_config.data_source.info_list.file_info_list: - raise ValueError("File source info is required") - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids - for file_id in upload_file_list: - file = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() + else: + # Fallback when no process_rule provided in knowledge_config: + # 1) reuse dataset.latest_process_rule if present + # 2) otherwise create an automatic rule + dataset_process_rule = getattr(dataset, "latest_process_rule", None) + if not dataset_process_rule: + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode="automatic", + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, ) - - # raise error if file not found - if not file: - raise FileNotExistsError() - - file_name = file.name - data_source_info: dict[str, str | bool] = { - "upload_file_id": file_id, - } - # check duplicate - if knowledge_config.duplicate: - document = ( - db.session.query(Document) - .filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ) - .first() + db.session.add(dataset_process_rule) + db.session.flush() + lock_name = f"add_document_lock_dataset_id_{dataset.id}" + try: + with redis_client.lock(lock_name, timeout=600): + assert dataset_process_rule + position = DocumentService.get_documents_position(dataset.id) + document_ids = [] + duplicate_document_ids = [] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + if not knowledge_config.data_source.info_list.file_info_list: + raise ValueError("File source info is required") + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids + files = ( + db.session.query(UploadFile) + .where( + UploadFile.tenant_id == dataset.tenant_id, + UploadFile.id.in_(upload_file_list), ) - if document: + .all() + ) + if len(files) != len(set(upload_file_list)): + raise FileNotExistsError("One or more files not found.") + + file_names = [file.name for file in files] + db_documents = ( + db.session.query(Document) + .where( + Document.dataset_id == dataset.id, + Document.tenant_id == current_user.current_tenant_id, + Document.data_source_type == "upload_file", + Document.enabled == True, + Document.name.in_(file_names), + ) + .all() + ) + documents_map = {document.name: document for document in db_documents} + for file in files: + data_source_info: dict[str, str | bool] = { + "upload_file_id": file.id, + } + document = documents_map.get(file.name) + if knowledge_config.duplicate and document: document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = naive_utc_now() document.created_from = created_from @@ -1643,58 +1703,7 @@ class DocumentService: documents.append(document) duplicate_document_ids.append(document.id) continue - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - knowledge_config.data_source.info_list.data_source_type, - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - file_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": - notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore - if not notion_info_list: - raise ValueError("No notion info list found.") - exist_page_ids = [] - exist_document = {} - documents = ( - db.session.query(Document) - .filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", - enabled=True, - ) - .all() - ) - if documents: - for document in documents: - data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info["notion_page_id"]) - exist_document[data_source_info["notion_page_id"]] = document.id - for notion_info in notion_info_list: - workspace_id = notion_info.workspace_id - for page in notion_info.pages: - if page.page_id not in exist_page_ids: - data_source_info = { - "credential_id": notion_info.credential_id, - "notion_workspace_id": workspace_id, - "notion_page_id": page.page_id, - "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore - "type": page.type, - } - # Truncate page name to 255 characters to prevent DB field length errors - truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" + else: document = DocumentService.build_document( dataset, dataset_process_rule.id, @@ -1705,7 +1714,7 @@ class DocumentService: created_from, position, account, - truncated_page_name, + file.name, batch, ) db.session.add(document) @@ -1713,53 +1722,109 @@ class DocumentService: document_ids.append(document.id) documents.append(document) position += 1 - else: - exist_document.pop(page.page_id) - # delete not selected documents - if len(exist_document) > 0: - clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": - website_info = knowledge_config.data_source.info_list.website_info_list - if not website_info: - raise ValueError("No website info list found.") - urls = website_info.urls - for url in urls: - data_source_info = { - "url": url, - "provider": website_info.provider, - "job_id": website_info.job_id, - "only_main_content": website_info.only_main_content, - "mode": "crawl", - } - if len(url) > 255: - document_name = url[:200] + "..." - else: - document_name = url - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - knowledge_config.data_source.info_list.data_source_type, - knowledge_config.doc_form, - knowledge_config.doc_language, - data_source_info, - created_from, - position, - account, - document_name, - batch, + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + if not notion_info_list: + raise ValueError("No notion info list found.") + exist_page_ids = [] + exist_document = {} + documents = ( + db.session.query(Document) + .filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ) + .all() ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - db.session.commit() + if documents: + for document in documents: + data_source_info = json.loads(document.data_source_info) + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id + for notion_info in notion_info_list: + workspace_id = notion_info.workspace_id + for page in notion_info.pages: + if page.page_id not in exist_page_ids: + data_source_info = { + "credential_id": notion_info.credential_id, + "notion_workspace_id": workspace_id, + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore + "type": page.type, + } + # Truncate page name to 255 characters to prevent DB field length errors + truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + truncated_page_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + else: + exist_document.pop(page.page_id) + # delete not selected documents + if len(exist_document) > 0: + clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + if not website_info: + raise ValueError("No website info list found.") + urls = website_info.urls + for url in urls: + data_source_info = { + "url": url, + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, + "mode": "crawl", + } + if len(url) > 255: + document_name = url[:200] + "..." + else: + document_name = url + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + document_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + db.session.commit() - # trigger async task - if document_ids: - DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay() - if duplicate_document_ids: - duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + # trigger async task + if document_ids: + DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay() + if duplicate_document_ids: + DuplicateDocumentIndexingTaskProxy( + dataset.tenant_id, dataset.id, duplicate_document_ids + ).delay() + except LockNotOwnedError: + pass return documents, batch @@ -2299,6 +2364,7 @@ class DocumentService: embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, retrieval_model=retrieval_model.model_dump() if retrieval_model else None, + is_multimodal=knowledge_config.is_multimodal, ) db.session.add(dataset) @@ -2679,6 +2745,13 @@ class SegmentService: if "content" not in args or not args["content"] or not args["content"].strip(): raise ValueError("Content is empty") + if args.get("attachment_ids"): + if not isinstance(args["attachment_ids"], list): + raise ValueError("Attachment IDs is invalid") + single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT + if len(args["attachment_ids"]) > single_chunk_attachment_limit: + raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}") + @classmethod def create_segment(cls, args: dict, document: Document, dataset: Dataset): assert isinstance(current_user, Account) @@ -2699,30 +2772,31 @@ class SegmentService: # calc embedding use tokens tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] lock_name = f"add_segment_lock_document_id_{document.id}" - with redis_client.lock(lock_name, timeout=600): - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == document.id) - .scalar() - ) - segment_document = DocumentSegment( - tenant_id=current_user.current_tenant_id, - dataset_id=document.dataset_id, - document_id=document.id, - index_node_id=doc_id, - index_node_hash=segment_hash, - position=max_position + 1 if max_position else 1, - content=content, - word_count=len(content), - tokens=tokens, - status="completed", - indexing_at=naive_utc_now(), - completed_at=naive_utc_now(), - created_by=current_user.id, - ) - if document.doc_form == "qa_model": - segment_document.word_count += len(args["answer"]) - segment_document.answer = args["answer"] + try: + with redis_client.lock(lock_name, timeout=600): + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .where(DocumentSegment.document_id == document.id) + .scalar() + ) + segment_document = DocumentSegment( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + status="completed", + indexing_at=naive_utc_now(), + completed_at=naive_utc_now(), + created_by=current_user.id, + ) + if document.doc_form == "qa_model": + segment_document.word_count += len(args["answer"]) + segment_document.answer = args["answer"] db.session.add(segment_document) # update document word count @@ -2731,18 +2805,34 @@ class SegmentService: db.session.add(document) db.session.commit() - # save vector index - try: - VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form) - except Exception as e: - logger.exception("create segment index failed") - segment_document.enabled = False - segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" - segment_document.error = str(e) + if args["attachment_ids"]: + for attachment_id in args["attachment_ids"]: + binding = SegmentAttachmentBinding( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + segment_id=segment_document.id, + attachment_id=attachment_id, + ) + db.session.add(binding) db.session.commit() - segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() - return segment + + # save vector index + try: + VectorService.create_segments_vector( + [args["keywords"]], [segment_document], dataset, document.doc_form + ) + except Exception as e: + logger.exception("create segment index failed") + segment_document.enabled = False + segment_document.disabled_at = naive_utc_now() + segment_document.status = "error" + segment_document.error = str(e) + db.session.commit() + segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() + return segment + except LockNotOwnedError: + pass @classmethod def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): @@ -2751,84 +2841,89 @@ class SegmentService: lock_name = f"multi_add_segment_lock_document_id_{document.id}" increment_word_count = 0 - with redis_client.lock(lock_name, timeout=600): - embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, + try: + with redis_client.lock(lock_name, timeout=600): + embedding_model = None + if dataset.indexing_technique == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .where(DocumentSegment.document_id == document.id) + .scalar() ) - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == document.id) - .scalar() - ) - pre_segment_data_list = [] - segment_data_list = [] - keywords_list = [] - position = max_position + 1 if max_position else 1 - for segment_item in segments: - content = segment_item["content"] - doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) - tokens = 0 - if dataset.indexing_technique == "high_quality" and embedding_model: - # calc embedding use tokens + pre_segment_data_list = [] + segment_data_list = [] + keywords_list = [] + position = max_position + 1 if max_position else 1 + for segment_item in segments: + content = segment_item["content"] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + tokens = 0 + if dataset.indexing_technique == "high_quality" and embedding_model: + # calc embedding use tokens + if document.doc_form == "qa_model": + tokens = embedding_model.get_text_embedding_num_tokens( + texts=[content + segment_item["answer"]] + )[0] + else: + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] + + segment_document = DocumentSegment( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=position, + content=content, + word_count=len(content), + tokens=tokens, + keywords=segment_item.get("keywords", []), + status="completed", + indexing_at=naive_utc_now(), + completed_at=naive_utc_now(), + created_by=current_user.id, + ) if document.doc_form == "qa_model": - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content + segment_item["answer"]] - )[0] + segment_document.answer = segment_item["answer"] + segment_document.word_count += len(segment_item["answer"]) + increment_word_count += segment_document.word_count + db.session.add(segment_document) + segment_data_list.append(segment_document) + position += 1 + + pre_segment_data_list.append(segment_document) + if "keywords" in segment_item: + keywords_list.append(segment_item["keywords"]) else: - tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] - - segment_document = DocumentSegment( - tenant_id=current_user.current_tenant_id, - dataset_id=document.dataset_id, - document_id=document.id, - index_node_id=doc_id, - index_node_hash=segment_hash, - position=position, - content=content, - word_count=len(content), - tokens=tokens, - keywords=segment_item.get("keywords", []), - status="completed", - indexing_at=naive_utc_now(), - completed_at=naive_utc_now(), - created_by=current_user.id, - ) - if document.doc_form == "qa_model": - segment_document.answer = segment_item["answer"] - segment_document.word_count += len(segment_item["answer"]) - increment_word_count += segment_document.word_count - db.session.add(segment_document) - segment_data_list.append(segment_document) - position += 1 - - pre_segment_data_list.append(segment_document) - if "keywords" in segment_item: - keywords_list.append(segment_item["keywords"]) - else: - keywords_list.append(None) - # update document word count - assert document.word_count is not None - document.word_count += increment_word_count - db.session.add(document) - try: - # save vector index - VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form) - except Exception as e: - logger.exception("create segment index failed") - for segment_document in segment_data_list: - segment_document.enabled = False - segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" - segment_document.error = str(e) - db.session.commit() - return segment_data_list + keywords_list.append(None) + # update document word count + assert document.word_count is not None + document.word_count += increment_word_count + db.session.add(document) + try: + # save vector index + VectorService.create_segments_vector( + keywords_list, pre_segment_data_list, dataset, document.doc_form + ) + except Exception as e: + logger.exception("create segment index failed") + for segment_document in segment_data_list: + segment_document.enabled = False + segment_document.disabled_at = naive_utc_now() + segment_document.status = "error" + segment_document.error = str(e) + db.session.commit() + return segment_data_list + except LockNotOwnedError: + pass @classmethod def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): @@ -2883,7 +2978,7 @@ class SegmentService: document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance if dataset.indexing_technique == "high_quality": @@ -2910,12 +3005,11 @@ class SegmentService: .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) - if not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + if processing_rule: + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): if args.enabled or keyword_changed: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) @@ -2960,7 +3054,7 @@ class SegmentService: db.session.add(document) db.session.add(segment) db.session.commit() - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance if dataset.indexing_technique == "high_quality": # check embedding model setting @@ -2986,15 +3080,15 @@ class SegmentService: .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) - if not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + if processing_rule: + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) - + # update multimodel vector index + VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset) except Exception as e: logger.exception("update segment index failed") segment.enabled = False @@ -3032,7 +3126,9 @@ class SegmentService: ) child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] - delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids) + delete_segment_from_index_task.delay( + [segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids + ) db.session.delete(segment) # update document word count @@ -3081,7 +3177,9 @@ class SegmentService: # Start async cleanup with both parent and child node IDs if index_node_ids or child_node_ids: - delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids) + delete_segment_from_index_task.delay( + index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids + ) if document.word_count is None: document.word_count = 0 diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 81e0c0ecd4..eeb14072bd 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -29,8 +29,14 @@ def get_current_user(): from models.account import Account from models.model import EndUser - if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore - raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}") + try: + user_object = current_user._get_current_object() + except AttributeError: + # Handle case where current_user might not be a LocalProxy in test environments + user_object = current_user + + if not isinstance(user_object, (Account, EndUser)): + raise TypeError(f"current_user must be Account or EndUser, got {type(user_object).__name__}") return current_user diff --git a/api/services/document_indexing_proxy/__init__.py b/api/services/document_indexing_proxy/__init__.py new file mode 100644 index 0000000000..74195adbe1 --- /dev/null +++ b/api/services/document_indexing_proxy/__init__.py @@ -0,0 +1,11 @@ +from .base import DocumentTaskProxyBase +from .batch_indexing_base import BatchDocumentIndexingProxy +from .document_indexing_task_proxy import DocumentIndexingTaskProxy +from .duplicate_document_indexing_task_proxy import DuplicateDocumentIndexingTaskProxy + +__all__ = [ + "BatchDocumentIndexingProxy", + "DocumentIndexingTaskProxy", + "DocumentTaskProxyBase", + "DuplicateDocumentIndexingTaskProxy", +] diff --git a/api/services/document_indexing_proxy/base.py b/api/services/document_indexing_proxy/base.py new file mode 100644 index 0000000000..56e47857c9 --- /dev/null +++ b/api/services/document_indexing_proxy/base.py @@ -0,0 +1,111 @@ +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import cached_property +from typing import Any, ClassVar + +from enums.cloud_plan import CloudPlan +from services.feature_service import FeatureService + +logger = logging.getLogger(__name__) + + +class DocumentTaskProxyBase(ABC): + """ + Base proxy for all document processing tasks. + + Handles common logic: + - Feature/billing checks + - Dispatch routing based on plan + + Subclasses must define: + - QUEUE_NAME: Redis queue identifier + - NORMAL_TASK_FUNC: Task function for normal priority + - PRIORITY_TASK_FUNC: Task function for high priority + """ + + QUEUE_NAME: ClassVar[str] + NORMAL_TASK_FUNC: ClassVar[Callable[..., Any]] + PRIORITY_TASK_FUNC: ClassVar[Callable[..., Any]] + + def __init__(self, tenant_id: str, dataset_id: str): + """ + Initialize with minimal required parameters. + + Args: + tenant_id: Tenant identifier for billing/features + dataset_id: Dataset identifier for logging + """ + self._tenant_id = tenant_id + self._dataset_id = dataset_id + + @cached_property + def features(self): + return FeatureService.get_features(self._tenant_id) + + @abstractmethod + def _send_to_direct_queue(self, task_func: Callable[..., Any]): + """ + Send task directly to Celery queue without tenant isolation. + + Subclasses implement this to pass task-specific parameters. + + Args: + task_func: The Celery task function to call + """ + pass + + @abstractmethod + def _send_to_tenant_queue(self, task_func: Callable[..., Any]): + """ + Send task to tenant-isolated queue. + + Subclasses implement this to handle queue management. + + Args: + task_func: The Celery task function to call + """ + pass + + def _send_to_default_tenant_queue(self): + """Route to normal priority with tenant isolation.""" + self._send_to_tenant_queue(self.NORMAL_TASK_FUNC) + + def _send_to_priority_tenant_queue(self): + """Route to priority queue with tenant isolation.""" + self._send_to_tenant_queue(self.PRIORITY_TASK_FUNC) + + def _send_to_priority_direct_queue(self): + """Route to priority queue without tenant isolation.""" + self._send_to_direct_queue(self.PRIORITY_TASK_FUNC) + + def _dispatch(self): + """ + Dispatch task based on billing plan. + + Routing logic: + - Sandbox plan → normal queue + tenant isolation + - Paid plans → priority queue + tenant isolation + - Self-hosted → priority queue, no isolation + """ + logger.info( + "dispatch args: %s - %s - %s", + self._tenant_id, + self.features.billing.enabled, + self.features.billing.subscription.plan, + ) + # dispatch to different indexing queue with tenant isolation when billing enabled + if self.features.billing.enabled: + if self.features.billing.subscription.plan == CloudPlan.SANDBOX: + # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan + self._send_to_default_tenant_queue() + else: + # dispatch to priority pipeline queue with tenant self sub queue for other plans + self._send_to_priority_tenant_queue() + else: + # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise + self._send_to_priority_direct_queue() + + def delay(self): + """Public API: Queue the task asynchronously.""" + self._dispatch() diff --git a/api/services/document_indexing_proxy/batch_indexing_base.py b/api/services/document_indexing_proxy/batch_indexing_base.py new file mode 100644 index 0000000000..dd122f34a8 --- /dev/null +++ b/api/services/document_indexing_proxy/batch_indexing_base.py @@ -0,0 +1,76 @@ +import logging +from collections.abc import Callable, Sequence +from dataclasses import asdict +from typing import Any + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue + +from .base import DocumentTaskProxyBase + +logger = logging.getLogger(__name__) + + +class BatchDocumentIndexingProxy(DocumentTaskProxyBase): + """ + Base proxy for batch document indexing tasks (document_ids in plural). + + Adds: + - Tenant isolated queue management + - Batch document handling + """ + + def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Initialize with batch documents. + + Args: + tenant_id: Tenant identifier + dataset_id: Dataset identifier + document_ids: List of document IDs to process + """ + super().__init__(tenant_id, dataset_id) + self._document_ids = document_ids + self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, self.QUEUE_NAME) + + def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]): + """ + Send batch task to direct queue. + + Args: + task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids) + """ + logger.info("tenant %s send documents %s to direct queue", self._tenant_id, self._document_ids) + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + + def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], Any]): + """ + Send batch task to tenant-isolated queue. + + Args: + task_func: The Celery task function to call with (tenant_id, dataset_id, document_ids) + """ + logger.info( + "tenant %s send documents %s to tenant queue %s", self._tenant_id, self._document_ids, self.QUEUE_NAME + ) + if self._tenant_isolated_task_queue.get_task_key(): + # Add to waiting queue using List operations (lpush) + self._tenant_isolated_task_queue.push_tasks( + [ + asdict( + DocumentTask( + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + ) + ] + ) + logger.info("tenant %s push tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids) + else: + # Set flag and execute task + self._tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids + ) + logger.info("tenant %s init tasks: %s - %s", self._tenant_id, self._dataset_id, self._document_ids) diff --git a/api/services/document_indexing_proxy/document_indexing_task_proxy.py b/api/services/document_indexing_proxy/document_indexing_task_proxy.py new file mode 100644 index 0000000000..fce79a8387 --- /dev/null +++ b/api/services/document_indexing_proxy/document_indexing_task_proxy.py @@ -0,0 +1,12 @@ +from typing import ClassVar + +from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy +from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task + + +class DocumentIndexingTaskProxy(BatchDocumentIndexingProxy): + """Proxy for document indexing tasks.""" + + QUEUE_NAME: ClassVar[str] = "document_indexing" + NORMAL_TASK_FUNC = normal_document_indexing_task + PRIORITY_TASK_FUNC = priority_document_indexing_task diff --git a/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py new file mode 100644 index 0000000000..277cfbdcf1 --- /dev/null +++ b/api/services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py @@ -0,0 +1,15 @@ +from typing import ClassVar + +from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy +from tasks.duplicate_document_indexing_task import ( + normal_duplicate_document_indexing_task, + priority_duplicate_document_indexing_task, +) + + +class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy): + """Proxy for duplicate document indexing tasks.""" + + QUEUE_NAME: ClassVar[str] = "duplicate_document_indexing" + NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task + PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task diff --git a/api/services/document_indexing_task_proxy.py b/api/services/document_indexing_task_proxy.py deleted file mode 100644 index 861c84b586..0000000000 --- a/api/services/document_indexing_task_proxy.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging -from collections.abc import Callable, Sequence -from dataclasses import asdict -from functools import cached_property - -from core.entities.document_task import DocumentTask -from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from enums.cloud_plan import CloudPlan -from services.feature_service import FeatureService -from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task - -logger = logging.getLogger(__name__) - - -class DocumentIndexingTaskProxy: - def __init__(self, tenant_id: str, dataset_id: str, document_ids: Sequence[str]): - self._tenant_id = tenant_id - self._dataset_id = dataset_id - self._document_ids = document_ids - self._tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") - - @cached_property - def features(self): - return FeatureService.get_features(self._tenant_id) - - def _send_to_direct_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): - logger.info("send dataset %s to direct queue", self._dataset_id) - task_func.delay( # type: ignore - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - - def _send_to_tenant_queue(self, task_func: Callable[[str, str, Sequence[str]], None]): - logger.info("send dataset %s to tenant queue", self._dataset_id) - if self._tenant_isolated_task_queue.get_task_key(): - # Add to waiting queue using List operations (lpush) - self._tenant_isolated_task_queue.push_tasks( - [ - asdict( - DocumentTask( - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - ) - ] - ) - logger.info("push tasks: %s - %s", self._dataset_id, self._document_ids) - else: - # Set flag and execute task - self._tenant_isolated_task_queue.set_task_waiting_time() - task_func.delay( # type: ignore - tenant_id=self._tenant_id, dataset_id=self._dataset_id, document_ids=self._document_ids - ) - logger.info("init tasks: %s - %s", self._dataset_id, self._document_ids) - - def _send_to_default_tenant_queue(self): - self._send_to_tenant_queue(normal_document_indexing_task) - - def _send_to_priority_tenant_queue(self): - self._send_to_tenant_queue(priority_document_indexing_task) - - def _send_to_priority_direct_queue(self): - self._send_to_direct_queue(priority_document_indexing_task) - - def _dispatch(self): - logger.info( - "dispatch args: %s - %s - %s", - self._tenant_id, - self.features.billing.enabled, - self.features.billing.subscription.plan, - ) - # dispatch to different indexing queue with tenant isolation when billing enabled - if self.features.billing.enabled: - if self.features.billing.subscription.plan == CloudPlan.SANDBOX: - # dispatch to normal pipeline queue with tenant self sub queue for sandbox plan - self._send_to_default_tenant_queue() - else: - # dispatch to priority pipeline queue with tenant self sub queue for other plans - self._send_to_priority_tenant_queue() - else: - # dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise - self._send_to_priority_direct_queue() - - def delay(self): - self._dispatch() diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 131e90e195..7959734e89 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -124,6 +124,14 @@ class KnowledgeConfig(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None name: str | None = None + is_multimodal: bool = False + + +class SegmentCreateArgs(BaseModel): + content: str | None = None + answer: str | None = None + keywords: list[str] | None = None + attachment_ids: list[str] | None = None class SegmentUpdateArgs(BaseModel): @@ -132,6 +140,7 @@ class SegmentUpdateArgs(BaseModel): keywords: list[str] | None = None regenerate_child_chunks: bool = False enabled: bool | None = None + attachment_ids: list[str] | None = None class ChildChunkUpdateArgs(BaseModel): diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index d07badefa7..f405546909 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -69,6 +69,7 @@ class ProviderResponse(BaseModel): label: I18nObject description: I18nObject | None = None icon_small: I18nObject | None = None + icon_small_dark: I18nObject | None = None icon_large: I18nObject | None = None background: str | None = None help: ProviderHelpEntity | None = None @@ -92,6 +93,11 @@ class ProviderResponse(BaseModel): self.icon_small = I18nObject( en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) + if self.icon_small_dark is not None: + self.icon_small_dark = I18nObject( + en_US=f"{url_prefix}/icon_small_dark/en_US", + zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans", + ) if self.icon_large is not None: self.icon_large = I18nObject( @@ -109,6 +115,7 @@ class ProviderWithModelsResponse(BaseModel): provider: str label: I18nObject icon_small: I18nObject | None = None + icon_small_dark: I18nObject | None = None icon_large: I18nObject | None = None status: CustomConfigurationStatus models: list[ProviderModelWithStatusEntity] @@ -123,6 +130,11 @@ class ProviderWithModelsResponse(BaseModel): en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) + if self.icon_small_dark is not None: + self.icon_small_dark = I18nObject( + en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans" + ) + if self.icon_large is not None: self.icon_large = I18nObject( en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" @@ -147,6 +159,11 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) + if self.icon_small_dark is not None: + self.icon_small_dark = I18nObject( + en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans" + ) + if self.icon_large is not None: self.icon_large = I18nObject( en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 27936f6278..40faa85b9a 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -324,4 +324,5 @@ class ExternalDatasetService: ) if response.status_code == 200: return cast(list[Any], response.json().get("records", [])) - return [] + else: + raise ValueError(response.text) diff --git a/api/services/file_service.py b/api/services/file_service.py index 1980cd8d59..0911cf38c4 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,3 +1,4 @@ +import base64 import hashlib import os import uuid @@ -123,6 +124,15 @@ class FileService: return file_size <= file_size_limit + def get_file_base64(self, file_id: str) -> str: + upload_file = ( + self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + ) + if not upload_file: + raise NotFound("File not found") + blob = storage.load_once(upload_file.key) + return base64.b64encode(blob).decode() + def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile: if len(text_name) > 200: text_name = text_name[:200] diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index cdbd2355ca..8cbf3a25c3 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,3 +1,4 @@ +import json import logging import time from typing import Any @@ -5,6 +6,7 @@ from typing import Any from core.app.app_config.entities import ModelConfig from core.model_runtime.entities import LLMMode from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -32,6 +34,7 @@ class HitTestingService: account: Account, retrieval_model: Any, # FIXME drop this any external_retrieval_model: dict, + attachment_ids: list | None = None, limit: int = 10, ): start = time.perf_counter() @@ -41,7 +44,7 @@ class HitTestingService: retrieval_model = dataset.retrieval_model or default_retrieval_model document_ids_filter = None metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) - if metadata_filtering_conditions: + if metadata_filtering_conditions and query: dataset_retrieval = DatasetRetrieval() from core.app.app_config.entities import MetadataFilteringCondition @@ -66,6 +69,7 @@ class HitTestingService: retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)), dataset_id=dataset.id, query=query, + attachment_ids=attachment_ids, top_k=retrieval_model.get("top_k", 4), score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] @@ -80,17 +84,24 @@ class HitTestingService: end = time.perf_counter() logger.debug("Hit testing retrieve in %s seconds", end - start) - - dataset_query = DatasetQuery( - dataset_id=dataset.id, - content=query, - source="hit_testing", - source_app_id=None, - created_by_role="account", - created_by=account.id, - ) - - db.session.add(dataset_query) + dataset_queries = [] + if query: + content = {"content_type": QueryType.TEXT_QUERY, "content": query} + dataset_queries.append(content) + if attachment_ids: + for attachment_id in attachment_ids: + content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id} + dataset_queries.append(content) + if dataset_queries: + dataset_query = DatasetQuery( + dataset_id=dataset.id, + content=json.dumps(dataset_queries), + source="hit_testing", + source_app_id=None, + created_by_role="account", + created_by=account.id, + ) + db.session.add(dataset_query) db.session.commit() return cls.compact_retrieve_response(query, all_documents) @@ -101,8 +112,8 @@ class HitTestingService: dataset: Dataset, query: str, account: Account, - external_retrieval_model: dict, - metadata_filtering_conditions: dict, + external_retrieval_model: dict | None = None, + metadata_filtering_conditions: dict | None = None, ): if dataset.provider != "external": return { @@ -167,10 +178,15 @@ class HitTestingService: @classmethod def hit_testing_args_check(cls, args): - query = args["query"] + query = args.get("query") + attachment_ids = args.get("attachment_ids") - if not query or len(query) > 250: - raise ValueError("Query is required and cannot exceed 250 characters") + if not attachment_ids and not query: + raise ValueError("Query or attachment_ids is required") + if query and len(query) > 250: + raise ValueError("Query cannot exceed 250 characters") + if attachment_ids and not isinstance(attachment_ids, list): + raise ValueError("Attachment_ids must be a list") @staticmethod def escape_query_for_search(query: str) -> str: diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 50ddbbf681..eea382febe 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -70,15 +70,35 @@ class ModelProviderService: continue provider_config = provider_configuration.custom_configuration.provider - model_config = provider_configuration.custom_configuration.models + models = provider_configuration.custom_configuration.models can_added_models = provider_configuration.custom_configuration.can_added_models + # IMPORTANT: Never expose decrypted credentials in the provider list API. + # Sanitize custom model configurations by dropping the credentials payload. + sanitized_model_config = [] + if models: + from core.entities.provider_entities import CustomModelConfiguration # local import to avoid cycles + + for model in models: + sanitized_model_config.append( + CustomModelConfiguration( + model=model.model, + model_type=model.model_type, + credentials=None, # strip secrets from list view + current_credential_id=model.current_credential_id, + current_credential_name=model.current_credential_name, + available_model_credentials=model.available_model_credentials, + unadded_to_model_list=model.unadded_to_model_list, + ) + ) + provider_response = ProviderResponse( tenant_id=tenant_id, provider=provider_configuration.provider.provider, label=provider_configuration.provider.label, description=provider_configuration.provider.description, icon_small=provider_configuration.provider.icon_small, + icon_small_dark=provider_configuration.provider.icon_small_dark, icon_large=provider_configuration.provider.icon_large, background=provider_configuration.provider.background, help=provider_configuration.provider.help, @@ -94,7 +114,7 @@ class ModelProviderService: current_credential_id=getattr(provider_config, "current_credential_id", None), current_credential_name=getattr(provider_config, "current_credential_name", None), available_credentials=getattr(provider_config, "available_credentials", []), - custom_models=model_config, + custom_models=sanitized_model_config, can_added_models=can_added_models, ), system_configuration=SystemConfigurationResponse( @@ -402,6 +422,7 @@ class ModelProviderService: provider=provider, label=first_model.provider.label, icon_small=first_model.provider.icon_small, + icon_small_dark=first_model.provider.icon_small_dark, icon_large=first_model.provider.icon_large, status=CustomConfigurationStatus.ACTIVE, models=[ diff --git a/api/services/rag_pipeline/rag_pipeline_task_proxy.py b/api/services/rag_pipeline/rag_pipeline_task_proxy.py index 94dd7941da..1a7b104a70 100644 --- a/api/services/rag_pipeline/rag_pipeline_task_proxy.py +++ b/api/services/rag_pipeline/rag_pipeline_task_proxy.py @@ -38,21 +38,24 @@ class RagPipelineTaskProxy: upload_file = FileService(db.engine).upload_text( json_text, self._RAG_PIPELINE_INVOKE_ENTITIES_FILE_NAME, self._user_id, self._dataset_tenant_id ) + logger.info( + "tenant %s upload %d invoke entities", self._dataset_tenant_id, len(self._rag_pipeline_invoke_entities) + ) return upload_file.id def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): - logger.info("send file %s to direct queue", upload_file_id) + logger.info("tenant %s send file %s to direct queue", self._dataset_tenant_id, upload_file_id) task_func.delay( # type: ignore rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=self._dataset_tenant_id, ) def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable[[str, str], None]): - logger.info("send file %s to tenant queue", upload_file_id) + logger.info("tenant %s send file %s to tenant queue", self._dataset_tenant_id, upload_file_id) if self._tenant_isolated_task_queue.get_task_key(): # Add to waiting queue using List operations (lpush) self._tenant_isolated_task_queue.push_tasks([upload_file_id]) - logger.info("push tasks: %s", upload_file_id) + logger.info("tenant %s push tasks: %s", self._dataset_tenant_id, upload_file_id) else: # Set flag and execute task self._tenant_isolated_task_queue.set_task_waiting_time() @@ -60,7 +63,7 @@ class RagPipelineTaskProxy: rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=self._dataset_tenant_id, ) - logger.info("init tasks: %s", upload_file_id) + logger.info("tenant %s init tasks: %s", self._dataset_tenant_id, upload_file_id) def _send_to_default_tenant_queue(self, upload_file_id: str): self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task) diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 250d29f335..b3b6e36346 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -7,6 +7,7 @@ from httpx import get from sqlalchemy import select from core.entities.provider_entities import ProviderConfig +from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_runtime import ToolRuntime from core.tools.custom_tool.provider import ApiToolProviderController @@ -177,6 +178,9 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @staticmethod @@ -318,6 +322,9 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @staticmethod @@ -340,6 +347,9 @@ class ApiToolManageService: db.session.delete(provider) db.session.commit() + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @staticmethod diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 783f2f0d21..cf1d39fa25 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -12,6 +12,7 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper.name_generator import generate_incremental_name from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache +from core.helper.tool_provider_cache import ToolProviderListCache from core.plugin.entities.plugin_daemon import CredentialType from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort @@ -204,6 +205,9 @@ class BuiltinToolManageService: db_provider.name = name session.commit() + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) except Exception as e: session.rollback() raise ValueError(str(e)) @@ -282,6 +286,9 @@ class BuiltinToolManageService: session.add(db_provider) session.commit() + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) except Exception as e: session.rollback() raise ValueError(str(e)) @@ -402,6 +409,9 @@ class BuiltinToolManageService: ) cache.delete() + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @staticmethod @@ -423,6 +433,9 @@ class BuiltinToolManageService: # set new default provider target_provider.is_default = True session.commit() + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) return {"result": "success"} @staticmethod diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 7eedf76aed..d641fe0315 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -15,6 +15,7 @@ from sqlalchemy.orm import Session from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache +from core.helper.tool_provider_cache import ToolProviderListCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError @@ -164,6 +165,10 @@ class MCPToolManageService: self._session.add(mcp_tool) self._session.flush() + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True) return mcp_providers @@ -245,6 +250,9 @@ class MCPToolManageService: # Flush changes to database self._session.flush() + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) except IntegrityError as e: self._handle_integrity_error(e, name, server_url, server_identifier) @@ -253,6 +261,9 @@ class MCPToolManageService: mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id) self._session.delete(mcp_tool) + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + def list_providers( self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True ) -> list[ToolProviderApiEntity]: diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 51e9120b8d..038c462f15 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -1,5 +1,6 @@ import logging +from core.helper.tool_provider_cache import ToolProviderListCache from core.tools.entities.api_entities import ToolProviderTypeApiLiteral from core.tools.tool_manager import ToolManager from services.tools.tools_transform_service import ToolTransformService @@ -15,6 +16,14 @@ class ToolCommonService: :return: the list of tool providers """ + # Try to get from cache first + cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ) + if cached_result is not None: + logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ) + return cached_result + + # Cache miss - fetch from database + logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ) providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ) # add icon @@ -23,4 +32,7 @@ class ToolCommonService: result = [provider.to_dict() for provider in providers] + # Cache the result + ToolProviderListCache.set_cached_providers(tenant_id, typ, result) + return result diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index c2bfb4dde6..fe77ff2dc5 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,4 +1,5 @@ import json +import logging from collections.abc import Mapping from datetime import datetime from typing import Any @@ -6,6 +7,7 @@ from typing import Any from sqlalchemy import or_, select from sqlalchemy.orm import Session +from core.helper.tool_provider_cache import ToolProviderListCache from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity @@ -19,6 +21,8 @@ from models.tools import WorkflowToolProvider from models.workflow import Workflow from services.tools.tools_transform_service import ToolTransformService +logger = logging.getLogger(__name__) + class WorkflowToolManageService: """ @@ -88,6 +92,10 @@ class WorkflowToolManageService: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) + + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @classmethod @@ -175,6 +183,9 @@ class WorkflowToolManageService: ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @classmethod @@ -198,7 +209,7 @@ class WorkflowToolManageService: tools.append(ToolTransformService.workflow_provider_to_controller(provider)) except Exception: # skip deleted tools - pass + logger.exception("Failed to load workflow tool provider %s", provider.id) labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)]) @@ -237,6 +248,9 @@ class WorkflowToolManageService: db.session.commit() + # Invalidate tool providers cache + ToolProviderListCache.invalidate_cache(tenant_id) + return {"result": "success"} @classmethod diff --git a/api/services/vector_service.py b/api/services/vector_service.py index abc92a0181..f1fa33cb75 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -4,11 +4,14 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models import UploadFile +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import ParentMode @@ -21,9 +24,10 @@ class VectorService: cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): documents: list[Document] = [] + multimodal_documents: list[AttachmentDocument] = [] for segment in segments: - if doc_form == IndexType.PARENT_CHILD_INDEX: + if doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() if not dataset_document: logger.warning( @@ -70,12 +74,29 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, + "doc_type": DocType.TEXT, }, ) documents.append(rag_document) + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_document: AttachmentDocument = AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + multimodal_documents.append(multimodal_document) + index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor() + if len(documents) > 0: - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) + index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list) + if len(multimodal_documents) > 0: + index_processor.load(dataset, [], multimodal_documents, with_keywords=False) @classmethod def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset): @@ -130,6 +151,7 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, + "doc_type": DocType.TEXT, }, ) # use full doc mode to generate segment's child chunk @@ -226,3 +248,92 @@ class VectorService: def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): vector = Vector(dataset=dataset) vector.delete_by_ids([child_chunk.index_node_id]) + + @classmethod + def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): + if dataset.indexing_technique != "high_quality": + return + + attachments = segment.attachments + old_attachment_ids = [attachment["id"] for attachment in attachments] if attachments else [] + + # Check if there's any actual change needed + if set(attachment_ids) == set(old_attachment_ids): + return + + try: + vector = Vector(dataset=dataset) + if dataset.is_multimodal: + # Delete old vectors if they exist + if old_attachment_ids: + vector.delete_by_ids(old_attachment_ids) + + # Delete existing segment attachment bindings in one operation + db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete( + synchronize_session=False + ) + + if not attachment_ids: + db.session.commit() + return + + # Bulk fetch upload files - only fetch needed fields + upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() + + if not upload_file_list: + db.session.commit() + return + + # Create a mapping for quick lookup + upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list} + + # Prepare batch operations + bindings = [] + documents = [] + + # Create common metadata base to avoid repetition + base_metadata = { + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + } + + # Process attachments in the order specified by attachment_ids + for attachment_id in attachment_ids: + upload_file = upload_file_map.get(attachment_id) + if not upload_file: + logger.warning("Upload file not found for attachment_id: %s", attachment_id) + continue + + # Create segment attachment binding + bindings.append( + SegmentAttachmentBinding( + tenant_id=segment.tenant_id, + dataset_id=segment.dataset_id, + document_id=segment.document_id, + segment_id=segment.id, + attachment_id=upload_file.id, + ) + ) + + # Create document for vector indexing + documents.append( + Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id}) + ) + + # Bulk insert all bindings at once + if bindings: + db.session.add_all(bindings) + + # Add documents to vector store if any + if documents and dataset.is_multimodal: + vector.add_texts(documents, duplicate_check=True) + + # Single commit for all operations + db.session.commit() + + except Exception: + logger.exception("Failed to update multimodal vector for segment %s", segment.id) + db.session.rollback() + raise diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 933ad6b9e2..e7dead8a56 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -55,6 +56,7 @@ def add_document_to_index_task(dataset_document_id: str): ) documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -65,7 +67,7 @@ def add_document_to_index_task(dataset_document_id: str): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -81,11 +83,25 @@ def add_document_to_index_task(dataset_document_id: str): ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, documents) + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) # delete auto disable log db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete() diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 5f2a355d16..8608df6b8e 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -18,6 +18,7 @@ from models.dataset import ( DatasetQuery, Document, DocumentSegment, + SegmentAttachmentBinding, ) from models.model import UploadFile @@ -58,14 +59,20 @@ def clean_dataset_task( ) documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id) + ).all() # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace # This ensures all invalid doc_form values are properly handled if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()): # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup - from core.rag.index_processor.constant.index_type import IndexType + from core.rag.index_processor.constant.index_type import IndexStructureType - doc_form = IndexType.PARAGRAPH_INDEX + doc_form = IndexStructureType.PARAGRAPH_INDEX logger.info( click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow") ) @@ -90,6 +97,7 @@ def clean_dataset_task( for document in documents: db.session.delete(document) + # delete document file for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) @@ -107,6 +115,19 @@ def clean_dataset_task( ) db.session.delete(image_file) db.session.delete(segment) + # delete segment attachments + if attachments_with_bindings: + for binding, attachment_file in attachments_with_bindings: + try: + storage.delete(attachment_file.key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + binding.attachment_id, + ) + db.session.delete(attachment_file) + db.session.delete(binding) db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 62200715cc..6d2feb1da3 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_storage import storage -from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment +from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding from models.model import UploadFile logger = logging.getLogger(__name__) @@ -36,6 +36,16 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i raise Exception("Document has no dataset") segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == dataset.tenant_id, + SegmentAttachmentBinding.dataset_id == dataset_id, + SegmentAttachmentBinding.document_id == document_id, + ) + ).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -69,6 +79,19 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i logger.exception("Delete file failed when document deleted, file_id: %s", file_id) db.session.delete(file) db.session.commit() + # delete segment attachments + if attachments_with_bindings: + for binding, attachment_file in attachments_with_bindings: + try: + storage.delete(attachment_file.key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + binding.attachment_id, + ) + db.session.delete(attachment_file) + db.session.delete(binding) # delete dataset metadata binding db.session.query(DatasetMetadataBinding).where( diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index 713f149c38..3d13afdec0 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task # type: ignore -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -28,7 +29,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "upgrade": dataset_documents = ( @@ -119,6 +120,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): ) if segments: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -129,7 +131,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -145,9 +147,25 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents, with_keywords=False) + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index dc6ef6fb61..1c7de3b1ce 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -1,14 +1,14 @@ import logging import time -from typing import Literal import click from celery import shared_task from sqlalchemy import select -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]): +def deal_dataset_vector_index_task(dataset_id: str, action: str): """ Async deal dataset from index :param dataset_id: dataset_id @@ -32,7 +32,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": index_processor.clean(dataset, None, with_keywords=False) @@ -119,6 +119,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) if segments: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -129,7 +130,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -145,9 +146,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents, with_keywords=False) + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index e8cbd0f250..bea5c952cf 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -6,14 +6,15 @@ from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db -from models.dataset import Dataset, Document +from models.dataset import Dataset, Document, SegmentAttachmentBinding +from models.model import UploadFile logger = logging.getLogger(__name__) @shared_task(queue="dataset") def delete_segment_from_index_task( - index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None + index_node_ids: list, dataset_id: str, document_id: str, segment_ids: list, child_node_ids: list | None = None ): """ Async Remove segment from index @@ -49,6 +50,21 @@ def delete_segment_from_index_task( delete_child_chunks=True, precomputed_child_node_ids=child_node_ids, ) + if dataset.is_multimodal: + # delete segment attachment binding + segment_attachment_bindings = ( + db.session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) + for binding in segment_attachment_bindings: + db.session.delete(binding) + # delete upload file + db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) + db.session.commit() end_at = time.perf_counter() logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 9038dc179b..c2a3de29f4 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -8,7 +8,7 @@ from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset, DocumentSegment +from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument logger = logging.getLogger(__name__) @@ -59,6 +59,16 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen try: index_node_ids = [segment.index_node_id for segment in segments] + if dataset.is_multimodal: + segment_ids = [segment.id for segment in segments] + segment_attachment_bindings = ( + db.session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_node_ids.extend(attachment_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) end_at = time.perf_counter() diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index fee4430612..acbdab631b 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -114,7 +114,13 @@ def _document_indexing_with_tenant_queue( try: _document_indexing(dataset_id, document_ids) except Exception: - logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id) + logger.exception( + "Error processing document indexing %s for tenant %s: %s", + dataset_id, + tenant_id, + document_ids, + exc_info=True, + ) finally: tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "document_indexing") @@ -122,7 +128,7 @@ def _document_indexing_with_tenant_queue( # Use rpop to get the next task from the queue (FIFO order) next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks) + logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks) if next_tasks: for next_task in next_tasks: diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 6492e356a3..4078c8910e 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -1,13 +1,16 @@ import logging import time +from collections.abc import Callable, Sequence import click from celery import shared_task from sqlalchemy import select from configs import dify_config +from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -24,8 +27,55 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): :param dataset_id: :param document_ids: + .. warning:: TO BE DEPRECATED + This function will be deprecated and removed in a future version. + Use normal_duplicate_document_indexing_task or priority_duplicate_document_indexing_task instead. + Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids) """ + logger.warning("duplicate document indexing task received: %s - %s", dataset_id, document_ids) + _duplicate_document_indexing_task(dataset_id, document_ids) + + +def _duplicate_document_indexing_task_with_tenant_queue( + tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None] +): + try: + _duplicate_document_indexing_task(dataset_id, document_ids) + except Exception: + logger.exception( + "Error processing duplicate document indexing %s for tenant %s: %s", + dataset_id, + tenant_id, + document_ids, + exc_info=True, + ) + finally: + tenant_isolated_task_queue = TenantIsolatedTaskQueue(tenant_id, "duplicate_document_indexing") + + # Check if there are waiting tasks in the queue + # Use rpop to get the next task from the queue (FIFO order) + next_tasks = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) + + logger.info("duplicate document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks) + + if next_tasks: + for next_task in next_tasks: + document_task = DocumentTask(**next_task) + # Process the next waiting task + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + task_func.delay( # type: ignore + tenant_id=document_task.tenant_id, + dataset_id=document_task.dataset_id, + document_ids=document_task.document_ids, + ) + else: + # No more waiting tasks, clear the flag + tenant_isolated_task_queue.delete_task_key() + + +def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]): documents = [] start_at = time.perf_counter() @@ -110,3 +160,35 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id) finally: db.session.close() + + +@shared_task(queue="dataset") +def normal_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Async process duplicate documents + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage: normal_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("normal duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _duplicate_document_indexing_task_with_tenant_queue( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + +@shared_task(queue="priority_dataset") +def priority_duplicate_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: Sequence[str]): + """ + Async process duplicate documents + :param tenant_id: + :param dataset_id: + :param document_ids: + + Usage: priority_duplicate_document_indexing_task.delay(tenant_id, dataset_id, document_ids) + """ + logger.info("priority duplicate document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids) + _duplicate_document_indexing_task_with_tenant_queue( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 07c44f333e..7615469ed0 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -67,7 +68,7 @@ def enable_segment_to_index_task(segment_id: str): return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -83,8 +84,24 @@ def enable_segment_to_index_task(segment_id: str): ) child_documents.append(child_document) document.children = child_documents + multimodel_documents = [] + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodel_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) + # save vector index - index_processor.load(dataset, [document]) + index_processor.load(dataset, [document], multimodal_documents=multimodel_documents) end_at = time.perf_counter() logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index c5ca7a6171..9f17d09e18 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -5,9 +5,10 @@ import click from celery import shared_task from sqlalchemy import select -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -60,6 +61,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i try: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -71,7 +73,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -87,9 +89,24 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i ) child_documents.append(child_document) document.children = child_documents + + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents) + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) end_at = time.perf_counter() logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index 124971e8e2..e6492c230d 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -1,4 +1,5 @@ import json +import logging import operator import typing @@ -12,6 +13,8 @@ from core.plugin.impl.plugin import PluginInstaller from extensions.ext_redis import redis_client from models.account import TenantPluginAutoUpgradeStrategy +logger = logging.getLogger(__name__) + RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3 CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:" CACHE_REDIS_TTL = 60 * 15 # 15 minutes @@ -42,6 +45,7 @@ def _get_cached_manifest(plugin_id: str) -> typing.Union[MarketplacePluginDeclar return MarketplacePluginDeclaration.model_validate(cached_json) except Exception: + logger.exception("Failed to get cached manifest for plugin %s", plugin_id) return False @@ -63,7 +67,7 @@ def _set_cached_manifest(plugin_id: str, manifest: typing.Union[MarketplacePlugi except Exception: # If Redis fails, continue without caching # traceback.print_exc() - pass + logger.exception("Failed to set cached manifest for plugin %s", plugin_id) def marketplace_batch_fetch_plugin_manifests( diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index a7f61d9811..1eef361a92 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -47,6 +47,8 @@ def priority_rag_pipeline_run_task( ) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities)) + # Get Flask app object for thread context flask_app = current_app._get_current_object() # type: ignore @@ -66,7 +68,7 @@ def priority_rag_pipeline_run_task( end_at = time.perf_counter() logging.info( click.style( - f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" + f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" ) ) except Exception: @@ -78,7 +80,7 @@ def priority_rag_pipeline_run_task( # Check if there are waiting tasks in the queue # Use rpop to get the next task from the queue (FIFO order) next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids) + logger.info("priority rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids) if next_file_ids: for next_file_id in next_file_ids: diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 92f1dfb73d..275f5abe6e 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -47,6 +47,8 @@ def rag_pipeline_run_task( ) rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content) + logger.info("tenant %s received %d rag pipeline invoke entities", tenant_id, len(rag_pipeline_invoke_entities)) + # Get Flask app object for thread context flask_app = current_app._get_current_object() # type: ignore @@ -66,7 +68,7 @@ def rag_pipeline_run_task( end_at = time.perf_counter() logging.info( click.style( - f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" + f"tenant_id: {tenant_id}, Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green" ) ) except Exception: @@ -78,7 +80,7 @@ def rag_pipeline_run_task( # Check if there are waiting tasks in the queue # Use rpop to get the next task from the queue (FIFO order) next_file_ids = tenant_isolated_task_queue.pull_tasks(count=dify_config.TENANT_ISOLATED_TASK_CONCURRENCY) - logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids) + logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids) if next_file_ids: for next_file_id in next_file_ids: diff --git a/api/tests/fixtures/workflow/end_node_without_value_type_field_workflow.yml b/api/tests/fixtures/workflow/end_node_without_value_type_field_workflow.yml new file mode 100644 index 0000000000..a69339691d --- /dev/null +++ b/api/tests/fixtures/workflow/end_node_without_value_type_field_workflow.yml @@ -0,0 +1,127 @@ +app: + description: 'End node without value_type field reproduction' + icon: 🤖 + icon_background: '#FFEAD5' + mode: workflow + name: end_node_without_value_type_field_reproduction + use_icon_as_answer_icon: false +dependencies: [] +kind: app +version: 0.5.0 +workflow: + conversation_variables: [] + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_batch_limit: 10 + image_file_size_limit: 10 + single_chunk_attachment_limit: 10 + video_file_size_limit: 100 + workflow_file_upload_limit: 10 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - data: + isInIteration: false + isInLoop: false + sourceType: start + targetType: end + id: 1765423445456-source-1765423454810-target + source: '1765423445456' + sourceHandle: source + target: '1765423454810' + targetHandle: target + type: custom + zIndex: 0 + nodes: + - data: + selected: false + title: 用户输入 + type: start + variables: + - default: '' + hint: '' + label: query + max_length: 48 + options: [] + placeholder: '' + required: true + type: text-input + variable: query + height: 109 + id: '1765423445456' + position: + x: -48 + y: 261 + positionAbsolute: + x: -48 + y: 261 + selected: false + sourcePosition: right + targetPosition: left + type: custom + width: 242 + - data: + outputs: + - value_selector: + - '1765423445456' + - query + variable: query + selected: true + title: 输出 + type: end + height: 88 + id: '1765423454810' + position: + x: 382 + y: 282 + positionAbsolute: + x: 382 + y: 282 + selected: true + sourcePosition: right + targetPosition: left + type: custom + width: 242 + viewport: + x: 139 + y: -135 + zoom: 1 + rag_pipeline_variables: [] diff --git a/api/tests/fixtures/workflow/iteration_flatten_output_disabled_workflow.yml b/api/tests/fixtures/workflow/iteration_flatten_output_disabled_workflow.yml index 9cae6385c8..b2451c7a9e 100644 --- a/api/tests/fixtures/workflow/iteration_flatten_output_disabled_workflow.yml +++ b/api/tests/fixtures/workflow/iteration_flatten_output_disabled_workflow.yml @@ -233,7 +233,7 @@ workflow: - value_selector: - iteration_node - output - value_type: array[array[number]] + value_type: array[number] variable: output selected: false title: End diff --git a/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py b/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py new file mode 100644 index 0000000000..e55c12e678 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py @@ -0,0 +1,244 @@ +"""Integration tests for Trigger Provider subscription permission verification.""" + +import uuid +from unittest import mock + +import pytest +from flask.testing import FlaskClient + +from controllers.console.workspace import trigger_providers as trigger_providers_api +from libs.datetime_utils import naive_utc_now +from models import Tenant +from models.account import Account, TenantAccountJoin, TenantAccountRole + + +class TestTriggerProviderSubscriptionPermissions: + """Test permission verification for Trigger Provider subscription endpoints.""" + + @pytest.fixture + def mock_account(self, monkeypatch: pytest.MonkeyPatch): + """Create a mock Account for testing.""" + + account = Account(name="Test User", email="test@example.com") + account.id = str(uuid.uuid4()) + account.last_active_at = naive_utc_now() + account.created_at = naive_utc_now() + account.updated_at = naive_utc_now() + + # Create mock tenant + tenant = Tenant(name="Test Tenant") + tenant.id = str(uuid.uuid4()) + + mock_session_instance = mock.Mock() + + mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER) + monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join)) + + mock_scalars_result = mock.Mock() + mock_scalars_result.one.return_value = tenant + monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result)) + + mock_session_context = mock.Mock() + mock_session_context.__enter__.return_value = mock_session_instance + monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context) + + account.current_tenant = tenant + account.current_tenant_id = tenant.id + return account + + @pytest.mark.parametrize( + ("role", "list_status", "get_status", "update_status", "create_status", "build_status", "delete_status"), + [ + # Admin/Owner can do everything + (TenantAccountRole.OWNER, 200, 200, 200, 200, 200, 200), + (TenantAccountRole.ADMIN, 200, 200, 200, 200, 200, 200), + # Editor can list, get, update (parameters), but not create, build, or delete + (TenantAccountRole.EDITOR, 200, 200, 200, 403, 403, 403), + # Normal user cannot do anything + (TenantAccountRole.NORMAL, 403, 403, 403, 403, 403, 403), + # Dataset operator cannot do anything + (TenantAccountRole.DATASET_OPERATOR, 403, 403, 403, 403, 403, 403), + ], + ) + def test_trigger_subscription_permissions( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_account, + role: TenantAccountRole, + list_status: int, + get_status: int, + update_status: int, + create_status: int, + build_status: int, + delete_status: int, + ): + """Test that different roles have appropriate permissions for trigger subscription operations.""" + # Set user role + mock_account.role = role + + # Mock current user + monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) + + # Mock AccountService.load_user to prevent authentication issues + from services.account_service import AccountService + + mock_load_user = mock.Mock(return_value=mock_account) + monkeypatch.setattr(AccountService, "load_user", mock_load_user) + + # Test data + provider = "some_provider/some_trigger" + subscription_builder_id = str(uuid.uuid4()) + subscription_id = str(uuid.uuid4()) + + # Mock service methods + mock_list_subscriptions = mock.Mock(return_value=[]) + monkeypatch.setattr( + "services.trigger.trigger_provider_service.TriggerProviderService.list_trigger_provider_subscriptions", + mock_list_subscriptions, + ) + + mock_get_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.get_subscription_builder_by_id", + mock_get_subscription_builder, + ) + + mock_update_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_trigger_subscription_builder", + mock_update_subscription_builder, + ) + + mock_create_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", + mock_create_subscription_builder, + ) + + mock_update_and_build_builder = mock.Mock() + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_and_build_builder", + mock_update_and_build_builder, + ) + + mock_delete_provider = mock.Mock() + mock_delete_plugin_trigger = mock.Mock() + mock_db_session = mock.Mock() + mock_db_session.commit = mock.Mock() + + def mock_session_func(engine=None): + return mock_session_context + + mock_session_context = mock.Mock() + mock_session_context.__enter__.return_value = mock_db_session + mock_session_context.__exit__.return_value = None + + monkeypatch.setattr("services.trigger.trigger_provider_service.Session", mock_session_func) + monkeypatch.setattr("services.trigger.trigger_subscription_operator_service.Session", mock_session_func) + + monkeypatch.setattr( + "services.trigger.trigger_provider_service.TriggerProviderService.delete_trigger_provider", + mock_delete_provider, + ) + monkeypatch.setattr( + "services.trigger.trigger_subscription_operator_service.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription", + mock_delete_plugin_trigger, + ) + + # Test 1: List subscriptions (should work for Editor, Admin, Owner) + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/list", + headers=auth_header, + ) + assert response.status_code == list_status + + # Test 2: Get subscription builder (should work for Editor, Admin, Owner) + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/{subscription_builder_id}", + headers=auth_header, + ) + assert response.status_code == get_status + + # Test 3: Update subscription builder parameters (should work for Editor, Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/update/{subscription_builder_id}", + headers=auth_header, + json={"parameters": {"webhook_url": "https://example.com/webhook"}}, + ) + assert response.status_code == update_status + + # Test 4: Create subscription builder (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/create", + headers=auth_header, + json={"credential_type": "api_key"}, + ) + assert response.status_code == create_status + + # Test 5: Build/activate subscription (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/build/{subscription_builder_id}", + headers=auth_header, + json={"name": "Test Subscription"}, + ) + assert response.status_code == build_status + + # Test 6: Delete subscription (should only work for Admin, Owner) + response = test_client.post( + f"/console/api/workspaces/current/trigger-provider/{subscription_id}/subscriptions/delete", + headers=auth_header, + ) + assert response.status_code == delete_status + + @pytest.mark.parametrize( + ("role", "status"), + [ + (TenantAccountRole.OWNER, 200), + (TenantAccountRole.ADMIN, 200), + # Editor should be able to access logs for debugging + (TenantAccountRole.EDITOR, 200), + (TenantAccountRole.NORMAL, 403), + (TenantAccountRole.DATASET_OPERATOR, 403), + ], + ) + def test_trigger_subscription_logs_permissions( + self, + test_client: FlaskClient, + auth_header, + monkeypatch, + mock_account, + role: TenantAccountRole, + status: int, + ): + """Test that different roles have appropriate permissions for accessing subscription logs.""" + # Set user role + mock_account.role = role + + # Mock current user + monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) + + # Mock AccountService.load_user to prevent authentication issues + from services.account_service import AccountService + + mock_load_user = mock.Mock(return_value=mock_account) + monkeypatch.setattr(AccountService, "load_user", mock_load_user) + + # Test data + provider = "some_provider/some_trigger" + subscription_builder_id = str(uuid.uuid4()) + + # Mock service method + mock_list_logs = mock.Mock(return_value=[]) + monkeypatch.setattr( + "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.list_logs", + mock_list_logs, + ) + + # Test access to logs + response = test_client.get( + f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/logs/{subscription_builder_id}", + headers=auth_header, + ) + assert response.status_code == status diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 8cb3572c47..612210ef86 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -227,6 +227,7 @@ class TestModelProviderService: mock_provider_entity.label = {"en_US": "OpenAI", "zh_Hans": "OpenAI"} mock_provider_entity.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"} mock_provider_entity.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} + mock_provider_entity.icon_small_dark = None mock_provider_entity.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} mock_provider_entity.background = "#FF6B6B" mock_provider_entity.help = None @@ -300,6 +301,7 @@ class TestModelProviderService: mock_provider_entity_llm.label = {"en_US": "OpenAI", "zh_Hans": "OpenAI"} mock_provider_entity_llm.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"} mock_provider_entity_llm.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} + mock_provider_entity_llm.icon_small_dark = None mock_provider_entity_llm.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} mock_provider_entity_llm.background = "#FF6B6B" mock_provider_entity_llm.help = None @@ -313,6 +315,7 @@ class TestModelProviderService: mock_provider_entity_embedding.label = {"en_US": "Cohere", "zh_Hans": "Cohere"} mock_provider_entity_embedding.description = {"en_US": "Cohere provider", "zh_Hans": "Cohere 提供商"} mock_provider_entity_embedding.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} + mock_provider_entity_embedding.icon_small_dark = None mock_provider_entity_embedding.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} mock_provider_entity_embedding.background = "#4ECDC4" mock_provider_entity_embedding.help = None @@ -1023,6 +1026,7 @@ class TestModelProviderService: provider="openai", label={"en_US": "OpenAI", "zh_Hans": "OpenAI"}, icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}, + icon_small_dark=None, icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}, ), model="gpt-3.5-turbo", @@ -1040,6 +1044,7 @@ class TestModelProviderService: provider="openai", label={"en_US": "OpenAI", "zh_Hans": "OpenAI"}, icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}, + icon_small_dark=None, icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}, ), model="gpt-4", diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 9478bb9ddb..088d6ba6ba 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -95,7 +95,7 @@ class TestAddDocumentToIndexTask: created_by=account.id, indexing_status="completed", enabled=True, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db.session.add(document) db.session.commit() @@ -172,7 +172,9 @@ class TestAddDocumentToIndexTask: # Assert: Verify the expected outcomes # Verify index processor was called correctly - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify database state changes @@ -204,7 +206,7 @@ class TestAddDocumentToIndexTask: ) # Update document to use different index type - document.doc_form = IndexType.QA_INDEX + document.doc_form = IndexStructureType.QA_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -221,7 +223,9 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify different index type handling - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.QA_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters @@ -360,7 +364,7 @@ class TestAddDocumentToIndexTask: ) # Update document to use parent-child index type - document.doc_form = IndexType.PARENT_CHILD_INDEX + document.doc_form = IndexStructureType.PARENT_CHILD_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -391,7 +395,7 @@ class TestAddDocumentToIndexTask: # Assert: Verify parent-child index processing mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( - IndexType.PARENT_CHILD_INDEX + IndexStructureType.PARENT_CHILD_INDEX ) mock_external_service_dependencies["index_processor"].load.assert_called_once() @@ -465,8 +469,10 @@ class TestAddDocumentToIndexTask: # Act: Execute the task add_document_to_index_task(document.id) - # Assert: Verify index processing occurred with all completed segments - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + # Assert: Verify index processing occurred but with empty documents list + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with all completed segments @@ -532,7 +538,9 @@ class TestAddDocumentToIndexTask: assert len(remaining_logs) == 0 # Verify index processing occurred normally - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify segments were enabled @@ -699,7 +707,9 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify only eligible segments were processed - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 94e9b76965..37d886f569 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Dataset, Document, DocumentSegment, Tenant from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -164,7 +164,7 @@ class TestDeleteSegmentFromIndexTask: document.updated_at = fake.date_time_this_year() document.doc_type = kwargs.get("doc_type", "text") document.doc_metadata = kwargs.get("doc_metadata", {}) - document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX) + document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX) document.doc_language = kwargs.get("doc_language", "en") db_session_with_containers.add(document) @@ -244,8 +244,11 @@ class TestDeleteSegmentFromIndexTask: mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + # Extract segment IDs for the task + segment_ids = [segment.id for segment in segments] + # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None # Task should return None on success @@ -279,7 +282,7 @@ class TestDeleteSegmentFromIndexTask: index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] # Execute the task with non-existent dataset - result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id) + result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id, []) # Verify the task completed without exceptions assert result is None # Task should return None when dataset not found @@ -305,7 +308,7 @@ class TestDeleteSegmentFromIndexTask: index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] # Execute the task with non-existent document - result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id, []) # Verify the task completed without exceptions assert result is None # Task should return None when document not found @@ -330,9 +333,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with disabled document - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when document is disabled @@ -357,9 +361,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with archived document - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when document is archived @@ -386,9 +391,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with incomplete indexing - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when indexing is not completed @@ -409,7 +415,11 @@ class TestDeleteSegmentFromIndexTask: fake = Faker() # Test different document forms - document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX] + document_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in document_forms: # Create test data for each document form @@ -420,13 +430,14 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None @@ -469,6 +480,7 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor to raise an exception mock_processor = MagicMock() @@ -476,7 +488,7 @@ class TestDeleteSegmentFromIndexTask: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - should not raise exception - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without raising exceptions assert result is None # Task should return None even when exceptions occur @@ -518,7 +530,7 @@ class TestDeleteSegmentFromIndexTask: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, []) # Verify the task completed successfully assert result is None @@ -555,13 +567,14 @@ class TestDeleteSegmentFromIndexTask: # Create large number of segments segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py new file mode 100644 index 0000000000..aca4be1ffd --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -0,0 +1,763 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.duplicate_document_indexing_task import ( + _duplicate_document_indexing_task, # Core function + _duplicate_document_indexing_task_with_tenant_queue, # Tenant queue wrapper function + duplicate_document_indexing_task, # Deprecated old interface + normal_duplicate_document_indexing_task, # New normal task + priority_duplicate_document_indexing_task, # New priority task +) + + +class TestDuplicateDocumentIndexingTasks: + """Integration tests for duplicate document indexing tasks using testcontainers. + + This test class covers: + - Core _duplicate_document_indexing_task function + - Deprecated duplicate_document_indexing_task function + - New normal_duplicate_document_indexing_task function + - New priority_duplicate_document_indexing_task function + - Tenant queue wrapper _duplicate_document_indexing_task_with_tenant_queue function + - Document segment cleanup logic + """ + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_indexing_runner, + patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_feature_service, + patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_index_processor_factory, + ): + # Setup mock indexing runner + mock_runner_instance = MagicMock() + mock_indexing_runner.return_value = mock_runner_instance + + # Setup mock feature service + mock_features = MagicMock() + mock_features.billing.enabled = False + mock_feature_service.get_features.return_value = mock_features + + # Setup mock index processor factory + mock_processor = MagicMock() + mock_processor.clean = MagicMock() + mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + + yield { + "indexing_runner": mock_indexing_runner, + "indexing_runner_instance": mock_runner_instance, + "feature_service": mock_feature_service, + "features": mock_features, + "index_processor_factory": mock_index_processor_factory, + "index_processor": mock_processor, + } + + def _create_test_dataset_and_documents( + self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + ): + """ + Helper method to create a test dataset and documents for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + document_count: Number of documents to create + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(document_count): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def _create_test_dataset_with_segments( + self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2 + ): + """ + Helper method to create a test dataset with documents and segments. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + document_count: Number of documents to create + segments_per_doc: Number of segments per document + + Returns: + tuple: (dataset, documents, segments) - Created dataset, documents and segments + """ + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count + ) + + fake = Faker() + segments = [] + + # Create segments for each document + for document in documents: + for i in range(segments_per_doc): + segment = DocumentSegment( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + position=i, + index_node_id=f"{document.id}-node-{i}", + index_node_hash=fake.sha256(), + content=fake.text(max_nb_chars=200), + word_count=50, + tokens=100, + status="completed", + enabled=True, + indexing_at=fake.date_time_this_year(), + created_by=dataset.created_by, # Add required field + ) + db.session.add(segment) + segments.append(segment) + + db.session.commit() + + # Refresh to ensure all relationships are loaded + for document in documents: + db.session.refresh(document) + + return dataset, documents, segments + + def _create_test_dataset_with_billing_features( + self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ): + """ + Helper method to create a test dataset with billing features configured. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + billing_enabled: Whether billing is enabled + + Returns: + tuple: (dataset, documents) - Created dataset and document instances + """ + fake = Faker() + + # Create account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Create dataset + dataset = Dataset( + id=fake.uuid4(), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db.session.add(dataset) + db.session.commit() + + # Create documents + documents = [] + for i in range(3): + document = Document( + id=fake.uuid4(), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + documents.append(document) + + db.session.commit() + + # Configure billing features + mock_external_service_dependencies["features"].billing.enabled = billing_enabled + if billing_enabled: + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX + mock_external_service_dependencies["features"].vector_space.limit = 100 + mock_external_service_dependencies["features"].vector_space.size = 50 + + # Refresh dataset to ensure it's properly loaded + db.session.refresh(dataset) + + return dataset, documents + + def test_duplicate_document_indexing_task_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful duplicate document indexing with multiple documents. + + This test verifies: + - Proper dataset retrieval from database + - Correct document processing and status updates + - IndexingRunner integration + - Database state updates + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=3 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify the expected outcomes + # Verify indexing runner was called correctly + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were updated to parsing status + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify the run method was called with correct documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 3 + + def test_duplicate_document_indexing_task_with_segment_cleanup( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test duplicate document indexing with existing segments that need cleanup. + + This test verifies: + - Old segments are identified and cleaned + - Index processor clean method is called + - Segments are deleted from database + - New indexing proceeds after cleanup + """ + # Arrange: Create test data with existing segments + dataset, documents, segments = self._create_test_dataset_with_segments( + db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify segment cleanup + # Verify index processor clean was called for each document with segments + assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents) + + # Verify segments were deleted from database + # Re-query segments from database since _duplicate_document_indexing_task uses a different session + for segment in segments: + deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() + assert deleted_segment is None + + # Verify documents were updated to parsing status + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify indexing runner was called + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_duplicate_document_indexing_task_dataset_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent dataset. + + This test verifies: + - Proper error handling for missing datasets + - Early return without processing + - Database session cleanup + - No unnecessary indexing runner calls + """ + # Arrange: Use non-existent dataset ID + fake = Faker() + non_existent_dataset_id = fake.uuid4() + document_ids = [fake.uuid4() for _ in range(3)] + + # Act: Execute the task with non-existent dataset + _duplicate_document_indexing_task(non_existent_dataset_id, document_ids) + + # Assert: Verify no processing occurred + mock_external_service_dependencies["indexing_runner"].assert_not_called() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + mock_external_service_dependencies["index_processor"].clean.assert_not_called() + + def test_duplicate_document_indexing_task_document_not_found_in_dataset( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling when some documents don't exist in the dataset. + + This test verifies: + - Only existing documents are processed + - Non-existent documents are ignored + - Indexing runner receives only valid documents + - Database state updates correctly + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + + # Mix existing and non-existent document IDs + fake = Faker() + existing_document_ids = [doc.id for doc in documents] + non_existent_document_ids = [fake.uuid4() for _ in range(2)] + all_document_ids = existing_document_ids + non_existent_document_ids + + # Act: Execute the task with mixed document IDs + _duplicate_document_indexing_task(dataset.id, all_document_ids) + + # Assert: Verify only existing documents were processed + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify only existing documents were updated + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in existing_document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + # Verify the run method was called with only existing documents + call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args + assert call_args is not None + processed_documents = call_args[0][0] # First argument should be documents list + assert len(processed_documents) == 2 # Only existing documents + + def test_duplicate_document_indexing_task_indexing_runner_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of IndexingRunner exceptions. + + This test verifies: + - Exceptions from IndexingRunner are properly caught + - Task completes without raising exceptions + - Database session is properly closed + - Error logging occurs + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock IndexingRunner to raise an exception + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception( + "Indexing runner failed" + ) + + # Act: Execute the task + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify exception was handled gracefully + # The task should complete without raising exceptions + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify documents were still updated to parsing status before the exception + # Re-query documents from database since _duplicate_document_indexing_task close the session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + + def test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test billing validation for sandbox plan batch upload limit. + + This test verifies: + - Sandbox plan batch upload limit enforcement + - Error handling for batch upload limit exceeded + - Document status updates to error state + - Proper error message recording + """ + # Arrange: Create test data with billing enabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ) + + # Configure sandbox plan with batch limit + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX + + # Create more documents than sandbox plan allows (limit is 1) + fake = Faker() + extra_documents = [] + for i in range(2): # Total will be 5 documents (3 existing + 2 new) + document = Document( + id=fake.uuid4(), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=i + 3, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=dataset.created_by, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db.session.add(document) + extra_documents.append(document) + + db.session.commit() + all_documents = documents + extra_documents + document_ids = [doc.id for doc in all_documents] + + # Act: Execute the task with too many documents for sandbox plan + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify error handling + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "error" + assert updated_document.error is not None + assert "batch upload" in updated_document.error.lower() + assert updated_document.stopped_at is not None + + # Verify indexing runner was not called due to early validation error + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + + def test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test billing validation for vector space limit. + + This test verifies: + - Vector space limit enforcement + - Error handling for vector space limit exceeded + - Document status updates to error state + - Proper error message recording + """ + # Arrange: Create test data with billing enabled + dataset, documents = self._create_test_dataset_with_billing_features( + db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + ) + + # Configure TEAM plan with vector space limit exceeded + mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.TEAM + mock_external_service_dependencies["features"].vector_space.limit = 100 + mock_external_service_dependencies["features"].vector_space.size = 98 # Almost at limit + + document_ids = [doc.id for doc in documents] # 3 documents will exceed limit + + # Act: Execute the task with documents that will exceed vector space limit + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify error handling + # Re-query documents from database since _duplicate_document_indexing_task uses a different session + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "error" + assert updated_document.error is not None + assert "limit" in updated_document.error.lower() + assert updated_document.stopped_at is not None + + # Verify indexing runner was not called due to early validation error + mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() + + def test_duplicate_document_indexing_task_with_empty_document_list( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of empty document list. + + This test verifies: + - Empty document list is handled gracefully + - No processing occurs + - No errors are raised + - Database session is properly closed + """ + # Arrange: Create test dataset + dataset, _ = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=0 + ) + document_ids = [] + + # Act: Execute the task with empty document list + _duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify IndexingRunner was called with empty list + # Note: The actual implementation does call run([]) with empty list + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([]) + + def test_deprecated_duplicate_document_indexing_task_delegates_to_core( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test that deprecated duplicate_document_indexing_task delegates to core function. + + This test verifies: + - Deprecated function calls core _duplicate_document_indexing_task + - Proper parameter passing + - Backward compatibility + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Act: Execute the deprecated task + duplicate_document_indexing_task(dataset.id, document_ids) + + # Assert: Verify core function was executed + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_normal_duplicate_document_indexing_task_with_tenant_queue( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test normal_duplicate_document_indexing_task with tenant isolation queue. + + This test verifies: + - Task uses tenant isolation queue correctly + - Core processing function is called + - Queue management (pull tasks, delete key) works properly + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock tenant isolated queue to return no next tasks + mock_queue = MagicMock() + mock_queue.pull_tasks.return_value = [] + mock_queue_class.return_value = mock_queue + + # Act: Execute the normal task + normal_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify tenant queue was used + mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing") + mock_queue.pull_tasks.assert_called_once() + mock_queue.delete_task_key.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_priority_duplicate_document_indexing_task_with_tenant_queue( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test priority_duplicate_document_indexing_task with tenant isolation queue. + + This test verifies: + - Task uses tenant isolation queue correctly + - Core processing function is called + - Queue management works properly + - Same behavior as normal task with different queue assignment + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Mock tenant isolated queue to return no next tasks + mock_queue = MagicMock() + mock_queue.pull_tasks.return_value = [] + mock_queue_class.return_value = mock_queue + + # Act: Execute the priority task + priority_duplicate_document_indexing_task(dataset.tenant_id, dataset.id, document_ids) + + # Assert: Verify processing occurred + mock_external_service_dependencies["indexing_runner"].assert_called_once() + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + # Verify tenant queue was used + mock_queue_class.assert_called_with(dataset.tenant_id, "duplicate_document_indexing") + mock_queue.pull_tasks.assert_called_once() + mock_queue.delete_task_key.assert_called_once() + + # Clear session cache to see database updates from task's session + db.session.expire_all() + + # Verify documents were processed + for doc_id in document_ids: + updated_document = db.session.query(Document).where(Document.id == doc_id).first() + assert updated_document.indexing_status == "parsing" + + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + def test_tenant_queue_wrapper_processes_next_tasks( + self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tenant queue wrapper processes next queued tasks. + + This test verifies: + - After completing current task, next tasks are pulled from queue + - Next tasks are executed correctly + - Task waiting time is set for next tasks + """ + # Arrange: Create test data + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + document_ids = [doc.id for doc in documents] + + # Extract values before session detachment + tenant_id = dataset.tenant_id + dataset_id = dataset.id + + # Mock tenant isolated queue to return next task + mock_queue = MagicMock() + next_task = { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": document_ids, + } + mock_queue.pull_tasks.return_value = [next_task] + mock_queue_class.return_value = mock_queue + + # Mock the task function to track calls + mock_task_func = MagicMock() + + # Act: Execute the wrapper function + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert: Verify next task was scheduled + mock_queue.pull_tasks.assert_called_once() + mock_queue.set_task_waiting_time.assert_called_once() + mock_task_func.delay.assert_called_once_with( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_ids=document_ids, + ) + mock_queue.delete_task_key.assert_not_called() diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 798fe091ab..b738646736 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -95,7 +95,7 @@ class TestEnableSegmentsToIndexTask: created_by=account.id, indexing_status="completed", enabled=True, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db.session.add(document) db.session.commit() @@ -166,7 +166,7 @@ class TestEnableSegmentsToIndexTask: ) # Update document to use different index type - document.doc_form = IndexType.QA_INDEX + document.doc_form = IndexStructureType.QA_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -185,7 +185,9 @@ class TestEnableSegmentsToIndexTask: enable_segments_to_index_task(segment_ids, dataset.id, document.id) # Assert: Verify different index type handling - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.QA_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters @@ -328,7 +330,9 @@ class TestEnableSegmentsToIndexTask: enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id) # Assert: Verify index processor was created but load was not called - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_with_parent_child_structure( @@ -350,7 +354,7 @@ class TestEnableSegmentsToIndexTask: ) # Update document to use parent-child index type - document.doc_form = IndexType.PARENT_CHILD_INDEX + document.doc_form = IndexStructureType.PARENT_CHILD_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -383,7 +387,7 @@ class TestEnableSegmentsToIndexTask: # Assert: Verify parent-child index processing mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( - IndexType.PARENT_CHILD_INDEX + IndexStructureType.PARENT_CHILD_INDEX ) mock_external_service_dependencies["index_processor"].load.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/trigger/__init__.py b/api/tests/test_containers_integration_tests/trigger/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/tests/test_containers_integration_tests/trigger/__init__.py @@ -0,0 +1 @@ + diff --git a/api/tests/test_containers_integration_tests/trigger/conftest.py b/api/tests/test_containers_integration_tests/trigger/conftest.py new file mode 100644 index 0000000000..9c1fd5e0ec --- /dev/null +++ b/api/tests/test_containers_integration_tests/trigger/conftest.py @@ -0,0 +1,182 @@ +""" +Fixtures for trigger integration tests. + +This module provides fixtures for creating test data (tenant, account, app) +and mock objects used across trigger-related tests. +""" + +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +import pytest +from sqlalchemy.orm import Session + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import App + + +@pytest.fixture +def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[Tenant, Account], None, None]: + """ + Create a tenant and account for testing. + + This fixture creates a tenant, account, and their association, + then cleans up after the test completes. + + Yields: + tuple[Tenant, Account]: The created tenant and account + """ + tenant = Tenant(name="trigger-e2e") + account = Account(name="tester", email="tester@example.com", interface_language="en-US") + db_session_with_containers.add_all([tenant, account]) + db_session_with_containers.commit() + + join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.OWNER.value) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + yield tenant, account + + # Cleanup + db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() + db_session_with_containers.query(Account).filter_by(id=account.id).delete() + db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete() + db_session_with_containers.commit() + + +@pytest.fixture +def app_model( + db_session_with_containers: Session, tenant_and_account: tuple[Tenant, Account] +) -> Generator[App, None, None]: + """ + Create an app for testing. + + This fixture creates a workflow app associated with the tenant and account, + then cleans up after the test completes. + + Yields: + App: The created app + """ + tenant, account = tenant_and_account + app = App( + tenant_id=tenant.id, + name="trigger-app", + description="trigger e2e", + mode="workflow", + icon_type="emoji", + icon="robot", + icon_background="#FFEAD5", + enable_site=True, + enable_api=True, + api_rpm=100, + api_rph=1000, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + + yield app + + # Cleanup - delete related records first + from models.trigger import ( + AppTrigger, + TriggerSubscription, + WorkflowPluginTrigger, + WorkflowSchedulePlan, + WorkflowTriggerLog, + WorkflowWebhookTrigger, + ) + from models.workflow import Workflow + + db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete() + db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete() + db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete() + db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete() + db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete() + db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete() + db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete() + db_session_with_containers.query(App).filter_by(id=app.id).delete() + db_session_with_containers.commit() + + +class MockCeleryGroup: + """Mock for celery group() function that collects dispatched tasks.""" + + def __init__(self) -> None: + self.collected: list[dict[str, Any]] = [] + self._applied = False + + def __call__(self, items: Any) -> MockCeleryGroup: + self.collected = list(items) + return self + + def apply_async(self) -> None: + self._applied = True + + @property + def applied(self) -> bool: + return self._applied + + +class MockCelerySignature: + """Mock for celery task signature that returns task info dict.""" + + def s(self, schedule_id: str) -> dict[str, str]: + return {"schedule_id": schedule_id} + + +@pytest.fixture +def mock_celery_group() -> MockCeleryGroup: + """ + Provide a mock celery group for testing task dispatch. + + Returns: + MockCeleryGroup: Mock group that collects dispatched tasks + """ + return MockCeleryGroup() + + +@pytest.fixture +def mock_celery_signature() -> MockCelerySignature: + """ + Provide a mock celery signature for testing task dispatch. + + Returns: + MockCelerySignature: Mock signature generator + """ + return MockCelerySignature() + + +class MockPluginSubscription: + """Mock plugin subscription for testing plugin triggers.""" + + def __init__( + self, + subscription_id: str = "sub-1", + tenant_id: str = "tenant-1", + provider_id: str = "provider-1", + ) -> None: + self.id = subscription_id + self.tenant_id = tenant_id + self.provider_id = provider_id + self.credentials: dict[str, str] = {"token": "secret"} + self.credential_type = "api-key" + + def to_entity(self) -> MockPluginSubscription: + return self + + +@pytest.fixture +def mock_plugin_subscription() -> MockPluginSubscription: + """ + Provide a mock plugin subscription for testing. + + Returns: + MockPluginSubscription: Mock subscription instance + """ + return MockPluginSubscription() diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py new file mode 100644 index 0000000000..604d68f257 --- /dev/null +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -0,0 +1,911 @@ +from __future__ import annotations + +import importlib +import json +import time +from datetime import timedelta +from types import SimpleNamespace +from typing import Any + +import pytest +from flask import Flask, Response +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from configs import dify_config +from core.plugin.entities.request import TriggerInvokeEventResponse +from core.trigger.debug import event_selectors +from core.trigger.debug.event_bus import TriggerDebugEventBus +from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller +from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key +from core.workflow.enums import NodeType +from libs.datetime_utils import naive_utc_now +from models.account import Account, Tenant +from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus +from models.model import App +from models.trigger import ( + AppTrigger, + TriggerSubscription, + WorkflowPluginTrigger, + WorkflowSchedulePlan, + WorkflowTriggerLog, + WorkflowWebhookTrigger, +) +from models.workflow import Workflow +from schedule import workflow_schedule_task +from schedule.workflow_schedule_task import poll_workflow_schedules +from services import feature_service as feature_service_module +from services.trigger import webhook_service +from services.trigger.schedule_service import ScheduleService +from services.workflow_service import WorkflowService +from tasks import trigger_processing_tasks + +from .conftest import MockCeleryGroup, MockCelerySignature, MockPluginSubscription + +# Test constants +WEBHOOK_ID_PRODUCTION = "wh1234567890123456789012" +WEBHOOK_ID_DEBUG = "whdebug1234567890123456" +TEST_TRIGGER_URL = "https://trigger.example.com/base" + + +def _build_workflow_graph(root_node_id: str, trigger_type: NodeType) -> str: + """Build a minimal workflow graph JSON for testing.""" + node_data: dict[str, Any] = {"type": trigger_type.value, "title": "trigger"} + if trigger_type == NodeType.TRIGGER_WEBHOOK: + node_data.update( + { + "method": "POST", + "content_type": "application/json", + "headers": [], + "params": [], + "body": [], + } + ) + graph = { + "nodes": [ + {"id": root_node_id, "data": node_data}, + {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + ], + "edges": [{"source": root_node_id, "target": "answer-1", "sourceHandle": "success"}], + } + return json.dumps(graph) + + +def test_publish_blocks_start_and_trigger_coexistence( + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Publishing should fail when both start and trigger nodes coexist.""" + tenant, account = tenant_and_account + + graph = { + "nodes": [ + {"id": "start", "data": {"type": NodeType.START.value}}, + {"id": "trig", "data": {"type": NodeType.TRIGGER_WEBHOOK.value}}, + ], + "edges": [], + } + draft_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.VERSION_DRAFT, + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(draft_workflow) + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + monkeypatch.setattr( + feature_service_module.FeatureService, + "get_system_features", + classmethod(lambda _cls: SimpleNamespace(plugin_manager=SimpleNamespace(enabled=False))), + ) + monkeypatch.setattr("services.workflow_service.dify_config", SimpleNamespace(BILLING_ENABLED=False)) + + with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"): + workflow_service.publish_workflow(session=db_session_with_containers, app_model=app_model, account=account) + + +def test_trigger_url_uses_config_base(monkeypatch: pytest.MonkeyPatch) -> None: + """TRIGGER_URL config should be reflected in generated webhook and plugin endpoints.""" + original_url = getattr(dify_config, "TRIGGER_URL", None) + + try: + monkeypatch.setattr(dify_config, "TRIGGER_URL", TEST_TRIGGER_URL) + endpoint_module = importlib.reload(importlib.import_module("core.trigger.utils.endpoint")) + + assert ( + endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION) + == f"{TEST_TRIGGER_URL}/triggers/webhook/{WEBHOOK_ID_PRODUCTION}" + ) + assert ( + endpoint_module.generate_webhook_trigger_endpoint(WEBHOOK_ID_PRODUCTION, True) + == f"{TEST_TRIGGER_URL}/triggers/webhook-debug/{WEBHOOK_ID_PRODUCTION}" + ) + assert ( + endpoint_module.generate_plugin_trigger_endpoint_url("end-1") == f"{TEST_TRIGGER_URL}/triggers/plugin/end-1" + ) + finally: + # Restore original config and reload module + if original_url is not None: + monkeypatch.setattr(dify_config, "TRIGGER_URL", original_url) + importlib.reload(importlib.import_module("core.trigger.utils.endpoint")) + + +def test_webhook_trigger_creates_trigger_log( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Production webhook trigger should create a trigger log in the database.""" + tenant, account = tenant_and_account + + webhook_node_id = "webhook-node" + graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK) + published_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.version_from_datetime(naive_utc_now()), + graph=graph_json, + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(published_workflow) + app_model.workflow_id = published_workflow.id + db_session_with_containers.commit() + + webhook_trigger = WorkflowWebhookTrigger( + app_id=app_model.id, + node_id=webhook_node_id, + tenant_id=tenant.id, + webhook_id=WEBHOOK_ID_PRODUCTION, + created_by=account.id, + ) + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=webhook_node_id, + trigger_type=AppTriggerType.TRIGGER_WEBHOOK, + status=AppTriggerStatus.ENABLED, + title="webhook", + ) + + db_session_with_containers.add_all([webhook_trigger, app_trigger]) + db_session_with_containers.commit() + + def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace: + log = WorkflowTriggerLog( + tenant_id=trigger_data.tenant_id, + app_id=trigger_data.app_id, + workflow_id=trigger_data.workflow_id, + root_node_id=trigger_data.root_node_id, + trigger_metadata=trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}", + trigger_type=trigger_data.trigger_type, + workflow_run_id=None, + outputs=None, + trigger_data=trigger_data.model_dump_json(), + inputs=json.dumps(dict(trigger_data.inputs)), + status=WorkflowTriggerStatus.SUCCEEDED, + error="", + queue_name="triggered_workflow_dispatcher", + celery_task_id="celery-test", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + session.add(log) + session.commit() + return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test") + + monkeypatch.setattr( + webhook_service.AsyncWorkflowService, + "trigger_workflow_async", + _fake_trigger_workflow_async, + ) + + response = test_client_with_containers.post(f"/triggers/webhook/{webhook_trigger.webhook_id}", json={"foo": "bar"}) + + assert response.status_code == 200 + + db_session_with_containers.expire_all() + logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all() + assert logs, "Webhook trigger should create trigger log" + + +@pytest.mark.parametrize("schedule_type", ["visual", "cron"]) +def test_schedule_poll_dispatches_due_plan( + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + mock_celery_group: MockCeleryGroup, + mock_celery_signature: MockCelerySignature, + monkeypatch: pytest.MonkeyPatch, + schedule_type: str, +) -> None: + """Schedule plans (both visual and cron) should be polled and dispatched when due.""" + tenant, _ = tenant_and_account + + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=f"schedule-{schedule_type}", + trigger_type=AppTriggerType.TRIGGER_SCHEDULE, + status=AppTriggerStatus.ENABLED, + title=f"schedule-{schedule_type}", + ) + plan = WorkflowSchedulePlan( + app_id=app_model.id, + node_id=f"schedule-{schedule_type}", + tenant_id=tenant.id, + cron_expression="* * * * *", + timezone="UTC", + next_run_at=naive_utc_now() - timedelta(minutes=1), + ) + db_session_with_containers.add_all([app_trigger, plan]) + db_session_with_containers.commit() + + next_time = naive_utc_now() + timedelta(hours=1) + monkeypatch.setattr(workflow_schedule_task, "calculate_next_run_at", lambda *_args, **_kwargs: next_time) + monkeypatch.setattr(workflow_schedule_task, "group", mock_celery_group) + monkeypatch.setattr(workflow_schedule_task, "run_schedule_trigger", mock_celery_signature) + + poll_workflow_schedules() + + assert mock_celery_group.collected, f"Should dispatch signatures for due {schedule_type} schedules" + scheduled_ids = {sig["schedule_id"] for sig in mock_celery_group.collected} + assert plan.id in scheduled_ids + + +def test_schedule_visual_debug_poll_generates_event(monkeypatch: pytest.MonkeyPatch) -> None: + """Visual mode schedule node should generate event in single-step debug.""" + base_now = naive_utc_now() + monkeypatch.setattr(event_selectors, "naive_utc_now", lambda: base_now) + monkeypatch.setattr( + event_selectors, + "calculate_next_run_at", + lambda *_args, **_kwargs: base_now - timedelta(minutes=1), + ) + node_config = { + "id": "schedule-visual", + "data": { + "type": NodeType.TRIGGER_SCHEDULE.value, + "mode": "visual", + "frequency": "daily", + "visual_config": {"time": "3:00 PM"}, + "timezone": "UTC", + }, + } + poller = event_selectors.ScheduleTriggerDebugEventPoller( + tenant_id="tenant", + user_id="user", + app_id="app", + node_config=node_config, + node_id="schedule-visual", + ) + event = poller.poll() + assert event is not None + assert event.workflow_args["inputs"] == {} + + +def test_plugin_trigger_dispatches_and_debug_events( + test_client_with_containers: FlaskClient, + mock_plugin_subscription: MockPluginSubscription, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin trigger endpoint should dispatch events and generate debug events.""" + endpoint_id = "1cc7fa12-3f7b-4f6a-9c8d-1234567890ab" + + debug_events: list[dict[str, Any]] = [] + dispatched_payloads: list[dict[str, Any]] = [] + + def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response: + dispatch_data = { + "user_id": "end-user", + "tenant_id": mock_plugin_subscription.tenant_id, + "endpoint_id": _endpoint_id, + "provider_id": mock_plugin_subscription.provider_id, + "subscription_id": mock_plugin_subscription.id, + "timestamp": int(time.time()), + "events": ["created", "updated"], + "request_id": f"req-{_endpoint_id}", + } + trigger_processing_tasks.dispatch_triggered_workflows_async.delay(dispatch_data) + return Response("ok", status=202) + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.process_endpoint", + staticmethod(_fake_process_endpoint), + ) + + monkeypatch.setattr( + trigger_processing_tasks.TriggerDebugEventBus, + "dispatch", + staticmethod(lambda **kwargs: debug_events.append(kwargs) or 1), + ) + + def _fake_delay(dispatch_data: dict[str, Any]) -> None: + dispatched_payloads.append(dispatch_data) + trigger_processing_tasks.dispatch_trigger_debug_event( + events=dispatch_data["events"], + user_id=dispatch_data["user_id"], + timestamp=dispatch_data["timestamp"], + request_id=dispatch_data["request_id"], + subscription=mock_plugin_subscription, + ) + + monkeypatch.setattr( + trigger_processing_tasks.dispatch_triggered_workflows_async, + "delay", + staticmethod(_fake_delay), + ) + + response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"hello": "world"}) + + assert response.status_code == 202 + assert dispatched_payloads, "Plugin trigger should enqueue workflow dispatch payload" + assert debug_events, "Plugin trigger should dispatch debug events" + dispatched_event_names = {event["event"].name for event in debug_events} + assert dispatched_event_names == {"created", "updated"} + + +def test_webhook_debug_dispatches_event( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Webhook single-step debug should dispatch debug event and be pollable.""" + tenant, account = tenant_and_account + webhook_node_id = "webhook-debug-node" + graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK) + draft_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.VERSION_DRAFT, + graph=graph_json, + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(draft_workflow) + db_session_with_containers.commit() + + webhook_trigger = WorkflowWebhookTrigger( + app_id=app_model.id, + node_id=webhook_node_id, + tenant_id=tenant.id, + webhook_id=WEBHOOK_ID_DEBUG, + created_by=account.id, + ) + db_session_with_containers.add(webhook_trigger) + db_session_with_containers.commit() + + debug_events: list[dict[str, Any]] = [] + original_dispatch = TriggerDebugEventBus.dispatch + monkeypatch.setattr( + "controllers.trigger.webhook.TriggerDebugEventBus.dispatch", + lambda **kwargs: (debug_events.append(kwargs), original_dispatch(**kwargs))[1], + ) + + # Listener polls first to enter waiting pool + poller = WebhookTriggerDebugEventPoller( + tenant_id=tenant.id, + user_id=account.id, + app_id=app_model.id, + node_config=draft_workflow.get_node_config_by_id(webhook_node_id), + node_id=webhook_node_id, + ) + assert poller.poll() is None + + response = test_client_with_containers.post( + f"/triggers/webhook-debug/{webhook_trigger.webhook_id}", + json={"foo": "bar"}, + headers={"Content-Type": "application/json"}, + ) + + assert response.status_code == 200 + assert debug_events, "Debug event should be sent to event bus" + # Second poll should get the event + event = poller.poll() + assert event is not None + assert event.workflow_args["inputs"]["webhook_body"]["foo"] == "bar" + assert debug_events[0]["pool_key"].endswith(f":{app_model.id}:{webhook_node_id}") + + +def test_plugin_single_step_debug_flow( + flask_app_with_containers: Flask, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin single-step debug: listen -> dispatch event -> poller receives and returns variables.""" + tenant_id = "tenant-1" + app_id = "app-1" + user_id = "user-1" + node_id = "plugin-node" + provider_id = "langgenius/provider-1/provider-1" + node_config = { + "id": node_id, + "data": { + "type": NodeType.TRIGGER_PLUGIN.value, + "title": "plugin", + "plugin_id": "plugin-1", + "plugin_unique_identifier": "plugin-1", + "provider_id": provider_id, + "event_name": "created", + "subscription_id": "sub-1", + "parameters": {}, + }, + } + # Start listening + poller = PluginTriggerDebugEventPoller( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + node_config=node_config, + node_id=node_id, + ) + assert poller.poll() is None + + from core.trigger.debug.events import build_plugin_pool_key + + pool_key = build_plugin_pool_key( + tenant_id=tenant_id, + provider_id=provider_id, + subscription_id="sub-1", + name="created", + ) + TriggerDebugEventBus.dispatch( + tenant_id=tenant_id, + event=PluginTriggerDebugEvent( + timestamp=int(time.time()), + user_id=user_id, + name="created", + request_id="req-1", + subscription_id="sub-1", + provider_id="provider-1", + ), + pool_key=pool_key, + ) + + from core.plugin.entities.request import TriggerInvokeEventResponse + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.invoke_trigger_event", + staticmethod( + lambda **_kwargs: TriggerInvokeEventResponse( + variables={"echo": "pong"}, + cancelled=False, + ) + ), + ) + + event = poller.poll() + assert event is not None + assert event.workflow_args["inputs"]["echo"] == "pong" + + +def test_schedule_trigger_creates_trigger_log( + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Schedule trigger execution should create WorkflowTriggerLog in database.""" + from tasks import workflow_schedule_tasks + + tenant, account = tenant_and_account + + # Create published workflow with schedule trigger node + schedule_node_id = "schedule-node" + graph = { + "nodes": [ + { + "id": schedule_node_id, + "data": { + "type": NodeType.TRIGGER_SCHEDULE.value, + "title": "schedule", + "mode": "cron", + "cron_expression": "0 9 * * *", + "timezone": "UTC", + }, + }, + {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + ], + "edges": [{"source": schedule_node_id, "target": "answer-1", "sourceHandle": "success"}], + } + published_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.version_from_datetime(naive_utc_now()), + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(published_workflow) + app_model.workflow_id = published_workflow.id + db_session_with_containers.commit() + + # Create schedule plan + plan = WorkflowSchedulePlan( + app_id=app_model.id, + node_id=schedule_node_id, + tenant_id=tenant.id, + cron_expression="0 9 * * *", + timezone="UTC", + next_run_at=naive_utc_now() - timedelta(minutes=1), + ) + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=schedule_node_id, + trigger_type=AppTriggerType.TRIGGER_SCHEDULE, + status=AppTriggerStatus.ENABLED, + title="schedule", + ) + db_session_with_containers.add_all([plan, app_trigger]) + db_session_with_containers.commit() + + # Mock AsyncWorkflowService to create WorkflowTriggerLog + def _fake_trigger_workflow_async(session: Session, user: Any, trigger_data: Any) -> SimpleNamespace: + log = WorkflowTriggerLog( + tenant_id=trigger_data.tenant_id, + app_id=trigger_data.app_id, + workflow_id=published_workflow.id, + root_node_id=trigger_data.root_node_id, + trigger_metadata="{}", + trigger_type=AppTriggerType.TRIGGER_SCHEDULE, + workflow_run_id=None, + outputs=None, + trigger_data=trigger_data.model_dump_json(), + inputs=json.dumps(dict(trigger_data.inputs)), + status=WorkflowTriggerStatus.SUCCEEDED, + error="", + queue_name="schedule_executor", + celery_task_id="celery-schedule-test", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + ) + session.add(log) + session.commit() + return SimpleNamespace(workflow_trigger_log_id=log.id, task_id=None, status="queued", queue="test") + + monkeypatch.setattr( + workflow_schedule_tasks.AsyncWorkflowService, + "trigger_workflow_async", + _fake_trigger_workflow_async, + ) + + # Mock quota to avoid rate limiting + from enums import quota_type + + monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited()) + + # Execute schedule trigger + workflow_schedule_tasks.run_schedule_trigger(plan.id) + + # Verify WorkflowTriggerLog was created + db_session_with_containers.expire_all() + logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all() + assert logs, "Schedule trigger should create WorkflowTriggerLog" + assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE + assert logs[0].root_node_id == schedule_node_id + + +@pytest.mark.parametrize( + ("mode", "frequency", "visual_config", "cron_expression", "expected_cron"), + [ + # Visual mode: hourly + ("visual", "hourly", {"on_minute": 30}, None, "30 * * * *"), + # Visual mode: daily + ("visual", "daily", {"time": "3:00 PM"}, None, "0 15 * * *"), + # Visual mode: weekly + ("visual", "weekly", {"time": "9:00 AM", "weekdays": ["mon", "wed", "fri"]}, None, "0 9 * * 1,3,5"), + # Visual mode: monthly + ("visual", "monthly", {"time": "10:30 AM", "monthly_days": [1, 15]}, None, "30 10 1,15 * *"), + # Cron mode: direct expression + ("cron", None, None, "*/5 * * * *", "*/5 * * * *"), + ], +) +def test_schedule_visual_cron_conversion( + mode: str, + frequency: str | None, + visual_config: dict[str, Any] | None, + cron_expression: str | None, + expected_cron: str, +) -> None: + """Schedule visual config should correctly convert to cron expression.""" + + node_config: dict[str, Any] = { + "id": "schedule-node", + "data": { + "type": NodeType.TRIGGER_SCHEDULE.value, + "mode": mode, + "timezone": "UTC", + }, + } + + if mode == "visual": + node_config["data"]["frequency"] = frequency + node_config["data"]["visual_config"] = visual_config + else: + node_config["data"]["cron_expression"] = cron_expression + + config = ScheduleService.to_schedule_config(node_config) + + assert config.cron_expression == expected_cron, f"Expected {expected_cron}, got {config.cron_expression}" + assert config.timezone == "UTC" + assert config.node_id == "schedule-node" + + +def test_plugin_trigger_full_chain_with_db_verification( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin trigger should create WorkflowTriggerLog and WorkflowPluginTrigger records.""" + + tenant, account = tenant_and_account + + # Create published workflow with plugin trigger node + plugin_node_id = "plugin-trigger-node" + provider_id = "langgenius/test-provider/test-provider" + subscription_id = "sub-plugin-test" + endpoint_id = "2cc7fa12-3f7b-4f6a-9c8d-1234567890ab" + + graph = { + "nodes": [ + { + "id": plugin_node_id, + "data": { + "type": NodeType.TRIGGER_PLUGIN.value, + "title": "plugin", + "plugin_id": "test-plugin", + "plugin_unique_identifier": "test-plugin", + "provider_id": provider_id, + "event_name": "test_event", + "subscription_id": subscription_id, + "parameters": {}, + }, + }, + {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + ], + "edges": [{"source": plugin_node_id, "target": "answer-1", "sourceHandle": "success"}], + } + published_workflow = Workflow.new( + tenant_id=tenant.id, + app_id=app_model.id, + type="workflow", + version=Workflow.version_from_datetime(naive_utc_now()), + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + db_session_with_containers.add(published_workflow) + app_model.workflow_id = published_workflow.id + db_session_with_containers.commit() + + # Create trigger subscription + subscription = TriggerSubscription( + name="test-subscription", + tenant_id=tenant.id, + user_id=account.id, + provider_id=provider_id, + endpoint_id=endpoint_id, + parameters={}, + properties={}, + credentials={"token": "test-secret"}, + credential_type="api-key", + ) + db_session_with_containers.add(subscription) + db_session_with_containers.commit() + + # Update subscription_id to match the created subscription + graph["nodes"][0]["data"]["subscription_id"] = subscription.id + published_workflow.graph = json.dumps(graph) + db_session_with_containers.commit() + + # Create WorkflowPluginTrigger + plugin_trigger = WorkflowPluginTrigger( + app_id=app_model.id, + tenant_id=tenant.id, + node_id=plugin_node_id, + provider_id=provider_id, + event_name="test_event", + subscription_id=subscription.id, + ) + app_trigger = AppTrigger( + tenant_id=tenant.id, + app_id=app_model.id, + node_id=plugin_node_id, + trigger_type=AppTriggerType.TRIGGER_PLUGIN, + status=AppTriggerStatus.ENABLED, + title="plugin", + ) + db_session_with_containers.add_all([plugin_trigger, app_trigger]) + db_session_with_containers.commit() + + # Track dispatched data + dispatched_data: list[dict[str, Any]] = [] + + def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response: + dispatch_data = { + "user_id": "end-user", + "tenant_id": tenant.id, + "endpoint_id": _endpoint_id, + "provider_id": provider_id, + "subscription_id": subscription.id, + "timestamp": int(time.time()), + "events": ["test_event"], + "request_id": f"req-{_endpoint_id}", + } + dispatched_data.append(dispatch_data) + return Response("ok", status=202) + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.process_endpoint", + staticmethod(_fake_process_endpoint), + ) + + response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"test": "data"}) + + assert response.status_code == 202 + assert dispatched_data, "Plugin trigger should dispatch event data" + assert dispatched_data[0]["subscription_id"] == subscription.id + assert dispatched_data[0]["events"] == ["test_event"] + + # Verify database records exist + db_session_with_containers.expire_all() + plugin_triggers = ( + db_session_with_containers.query(WorkflowPluginTrigger) + .filter_by(app_id=app_model.id, node_id=plugin_node_id) + .all() + ) + assert plugin_triggers, "WorkflowPluginTrigger record should exist" + assert plugin_triggers[0].provider_id == provider_id + assert plugin_triggers[0].event_name == "test_event" + + +def test_plugin_debug_via_http_endpoint( + test_client_with_containers: FlaskClient, + db_session_with_containers: Session, + tenant_and_account: tuple[Tenant, Account], + app_model: App, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Plugin single-step debug via HTTP endpoint should dispatch debug event and be pollable.""" + + tenant, account = tenant_and_account + + provider_id = "langgenius/debug-provider/debug-provider" + endpoint_id = "3cc7fa12-3f7b-4f6a-9c8d-1234567890ab" + event_name = "debug_event" + + # Create subscription + subscription = TriggerSubscription( + name="debug-subscription", + tenant_id=tenant.id, + user_id=account.id, + provider_id=provider_id, + endpoint_id=endpoint_id, + parameters={}, + properties={}, + credentials={"token": "debug-secret"}, + credential_type="api-key", + ) + db_session_with_containers.add(subscription) + db_session_with_containers.commit() + + # Create plugin trigger node config + node_id = "plugin-debug-node" + node_config = { + "id": node_id, + "data": { + "type": NodeType.TRIGGER_PLUGIN.value, + "title": "plugin-debug", + "plugin_id": "debug-plugin", + "plugin_unique_identifier": "debug-plugin", + "provider_id": provider_id, + "event_name": event_name, + "subscription_id": subscription.id, + "parameters": {}, + }, + } + + # Start listening with poller + + poller = PluginTriggerDebugEventPoller( + tenant_id=tenant.id, + user_id=account.id, + app_id=app_model.id, + node_config=node_config, + node_id=node_id, + ) + assert poller.poll() is None, "First poll should return None (waiting)" + + # Track debug events dispatched + debug_events: list[dict[str, Any]] = [] + original_dispatch = TriggerDebugEventBus.dispatch + + def _tracking_dispatch(**kwargs: Any) -> int: + debug_events.append(kwargs) + return original_dispatch(**kwargs) + + monkeypatch.setattr(TriggerDebugEventBus, "dispatch", staticmethod(_tracking_dispatch)) + + # Mock process_endpoint to trigger debug event dispatch + def _fake_process_endpoint(_endpoint_id: str, _request: Any) -> Response: + # Simulate what happens inside process_endpoint + dispatch_triggered_workflows_async + pool_key = build_plugin_pool_key( + tenant_id=tenant.id, + provider_id=provider_id, + subscription_id=subscription.id, + name=event_name, + ) + TriggerDebugEventBus.dispatch( + tenant_id=tenant.id, + event=PluginTriggerDebugEvent( + timestamp=int(time.time()), + user_id="end-user", + name=event_name, + request_id=f"req-{_endpoint_id}", + subscription_id=subscription.id, + provider_id=provider_id, + ), + pool_key=pool_key, + ) + return Response("ok", status=202) + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.process_endpoint", + staticmethod(_fake_process_endpoint), + ) + + # Call HTTP endpoint + response = test_client_with_containers.post(f"/triggers/plugin/{endpoint_id}", json={"debug": "payload"}) + + assert response.status_code == 202 + assert debug_events, "Debug event should be dispatched via HTTP endpoint" + assert debug_events[0]["event"].name == event_name + + # Mock invoke_trigger_event for poller + + monkeypatch.setattr( + "services.trigger.trigger_service.TriggerService.invoke_trigger_event", + staticmethod( + lambda **_kwargs: TriggerInvokeEventResponse( + variables={"http_debug": "success"}, + cancelled=False, + ) + ), + ) + + # Second poll should receive the event + event = poller.poll() + assert event is not None, "Poller should receive debug event after HTTP trigger" + assert event.workflow_args["inputs"]["http_debug"] == "success" diff --git a/api/tests/unit_tests/controllers/console/billing/test_billing.py b/api/tests/unit_tests/controllers/console/billing/test_billing.py index eaa489d56b..c80758c857 100644 --- a/api/tests/unit_tests/controllers/console/billing/test_billing.py +++ b/api/tests/unit_tests/controllers/console/billing/test_billing.py @@ -125,7 +125,7 @@ class TestPartnerTenants: resource = PartnerTenants() # Act & Assert - # reqparse will raise BadRequest for missing required field + # Validation should raise BadRequest for missing required field with pytest.raises(BadRequest): resource.put(partner_key_encoded) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_chat_request_payload.py b/api/tests/unit_tests/controllers/service_api/app/test_chat_request_payload.py new file mode 100644 index 0000000000..1fb7e7009d --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_chat_request_payload.py @@ -0,0 +1,25 @@ +import uuid + +import pytest +from pydantic import ValidationError + +from controllers.service_api.app.completion import ChatRequestPayload + + +def test_chat_request_payload_accepts_blank_conversation_id(): + payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": ""}) + + assert payload.conversation_id is None + + +def test_chat_request_payload_validates_uuid(): + conversation_id = str(uuid.uuid4()) + + payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": conversation_id}) + + assert payload.conversation_id == conversation_id + + +def test_chat_request_payload_rejects_invalid_uuid(): + with pytest.raises(ValidationError): + ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": "invalid"}) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py index 5c484403a6..acff191c79 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py @@ -256,24 +256,18 @@ class TestFilePreviewApi: mock_app, # App query for tenant validation ] - with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse: - # Mock request parsing - mock_parser = Mock() - mock_parser.parse_args.return_value = {"as_attachment": False} - mock_reqparse.RequestParser.return_value = mock_parser + # Test the core logic directly without Flask decorators + # Validate file ownership + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file - # Test the core logic directly without Flask decorators - # Validate file ownership - result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) - assert result_message_file == mock_message_file - assert result_upload_file == mock_upload_file + # Test file response building + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + assert response is not None - # Test file response building - response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) - assert response is not None - - # Verify storage was called correctly - mock_storage.load.assert_not_called() # Since we're testing components separately + # Verify storage was called correctly + mock_storage.load.assert_not_called() # Since we're testing components separately @patch("controllers.service_api.app.file_preview.storage") def test_storage_error_handling( diff --git a/api/tests/unit_tests/controllers/test_conversation_rename_payload.py b/api/tests/unit_tests/controllers/test_conversation_rename_payload.py new file mode 100644 index 0000000000..494176cbd9 --- /dev/null +++ b/api/tests/unit_tests/controllers/test_conversation_rename_payload.py @@ -0,0 +1,20 @@ +import pytest +from pydantic import ValidationError + +from controllers.console.explore.conversation import ConversationRenamePayload as ConsolePayload +from controllers.service_api.app.conversation import ConversationRenamePayload as ServicePayload + + +@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload]) +def test_payload_allows_auto_generate_without_name(payload_cls): + payload = payload_cls.model_validate({"auto_generate": True}) + + assert payload.auto_generate is True + assert payload.name is None + + +@pytest.mark.parametrize("payload_cls", [ConsolePayload, ServicePayload]) +@pytest.mark.parametrize("value", [None, "", " "]) +def test_payload_requires_name_when_not_auto_generate(payload_cls, value): + with pytest.raises(ValidationError): + payload_cls.model_validate({"name": value, "auto_generate": False}) diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index fdab39f133..d622c3a555 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -265,3 +265,82 @@ def test_validate_inputs_with_default_value(): ) assert result == [{"id": "file1", "name": "doc1.pdf"}, {"id": "file2", "name": "doc2.pdf"}] + + +def test_validate_inputs_optional_file_with_empty_string(): + """Test that optional FILE variable with empty string returns None""" + base_app_generator = BaseAppGenerator() + + var_file = VariableEntity( + variable="test_file", + label="test_file", + type=VariableEntityType.FILE, + required=False, + ) + + result = base_app_generator._validate_inputs( + variable_entity=var_file, + value="", + ) + + assert result is None + + +def test_validate_inputs_optional_file_list_with_empty_list(): + """Test that optional FILE_LIST variable with empty list returns None""" + base_app_generator = BaseAppGenerator() + + var_file_list = VariableEntity( + variable="test_file_list", + label="test_file_list", + type=VariableEntityType.FILE_LIST, + required=False, + ) + + result = base_app_generator._validate_inputs( + variable_entity=var_file_list, + value=[], + ) + + assert result is None + + +def test_validate_inputs_required_file_with_empty_string_fails(): + """Test that required FILE variable with empty string still fails validation""" + base_app_generator = BaseAppGenerator() + + var_file = VariableEntity( + variable="test_file", + label="test_file", + type=VariableEntityType.FILE, + required=True, + ) + + with pytest.raises(ValueError) as exc_info: + base_app_generator._validate_inputs( + variable_entity=var_file, + value="", + ) + + assert "must be a file" in str(exc_info.value) + + +def test_validate_inputs_optional_file_with_empty_string_ignores_default(): + """Test that optional FILE variable with empty string returns None, not the default""" + base_app_generator = BaseAppGenerator() + + var_file = VariableEntity( + variable="test_file", + label="test_file", + type=VariableEntityType.FILE, + required=False, + default={"id": "file123", "name": "default.pdf"}, + ) + + # When value is empty string (from frontend), should return None, not default + result = base_app_generator._validate_inputs( + variable_entity=var_file, + value="", + ) + + assert result is None diff --git a/api/tests/unit_tests/core/helper/test_tool_provider_cache.py b/api/tests/unit_tests/core/helper/test_tool_provider_cache.py new file mode 100644 index 0000000000..00f7c9d7e9 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_tool_provider_cache.py @@ -0,0 +1,129 @@ +import json +from unittest.mock import patch + +import pytest +from redis.exceptions import RedisError + +from core.helper.tool_provider_cache import ToolProviderListCache +from core.tools.entities.api_entities import ToolProviderTypeApiLiteral + + +@pytest.fixture +def mock_redis_client(): + """Fixture: Mock Redis client""" + with patch("core.helper.tool_provider_cache.redis_client") as mock: + yield mock + + +class TestToolProviderListCache: + """Test class for ToolProviderListCache""" + + def test_generate_cache_key(self): + """Test cache key generation logic""" + # Scenario 1: Specify typ (valid literal value) + tenant_id = "tenant_123" + typ: ToolProviderTypeApiLiteral = "builtin" + expected_key = f"tool_providers:tenant_id:{tenant_id}:type:{typ}" + assert ToolProviderListCache._generate_cache_key(tenant_id, typ) == expected_key + + # Scenario 2: typ is None (defaults to "all") + expected_key_all = f"tool_providers:tenant_id:{tenant_id}:type:all" + assert ToolProviderListCache._generate_cache_key(tenant_id) == expected_key_all + + def test_get_cached_providers_hit(self, mock_redis_client): + """Test get cached providers - cache hit and successful decoding""" + tenant_id = "tenant_123" + typ: ToolProviderTypeApiLiteral = "api" + mock_providers = [{"id": "tool", "name": "test_provider"}] + mock_redis_client.get.return_value = json.dumps(mock_providers).encode("utf-8") + + result = ToolProviderListCache.get_cached_providers(tenant_id, typ) + + mock_redis_client.get.assert_called_once_with(ToolProviderListCache._generate_cache_key(tenant_id, typ)) + assert result == mock_providers + + def test_get_cached_providers_decode_error(self, mock_redis_client): + """Test get cached providers - cache hit but decoding failed""" + tenant_id = "tenant_123" + mock_redis_client.get.return_value = b"invalid_json_data" + + result = ToolProviderListCache.get_cached_providers(tenant_id) + + assert result is None + mock_redis_client.get.assert_called_once() + + def test_get_cached_providers_miss(self, mock_redis_client): + """Test get cached providers - cache miss""" + tenant_id = "tenant_123" + mock_redis_client.get.return_value = None + + result = ToolProviderListCache.get_cached_providers(tenant_id) + + assert result is None + mock_redis_client.get.assert_called_once() + + def test_set_cached_providers(self, mock_redis_client): + """Test set cached providers""" + tenant_id = "tenant_123" + typ: ToolProviderTypeApiLiteral = "builtin" + mock_providers = [{"id": "tool", "name": "test_provider"}] + cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) + + ToolProviderListCache.set_cached_providers(tenant_id, typ, mock_providers) + + mock_redis_client.setex.assert_called_once_with( + cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(mock_providers) + ) + + def test_invalidate_cache_specific_type(self, mock_redis_client): + """Test invalidate cache - specific type""" + tenant_id = "tenant_123" + typ: ToolProviderTypeApiLiteral = "workflow" + cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ) + + ToolProviderListCache.invalidate_cache(tenant_id, typ) + + mock_redis_client.delete.assert_called_once_with(cache_key) + + def test_invalidate_cache_all_types(self, mock_redis_client): + """Test invalidate cache - clear all tenant cache""" + tenant_id = "tenant_123" + mock_keys = [ + b"tool_providers:tenant_id:tenant_123:type:all", + b"tool_providers:tenant_id:tenant_123:type:builtin", + ] + mock_redis_client.scan_iter.return_value = mock_keys + + ToolProviderListCache.invalidate_cache(tenant_id) + + mock_redis_client.scan_iter.assert_called_once_with(f"tool_providers:tenant_id:{tenant_id}:*") + mock_redis_client.delete.assert_called_once_with(*mock_keys) + + def test_invalidate_cache_no_keys(self, mock_redis_client): + """Test invalidate cache - no cache keys for tenant""" + tenant_id = "tenant_123" + mock_redis_client.scan_iter.return_value = [] + + ToolProviderListCache.invalidate_cache(tenant_id) + + mock_redis_client.delete.assert_not_called() + + def test_redis_fallback_default_return(self, mock_redis_client): + """Test redis_fallback decorator - default return value (Redis error)""" + mock_redis_client.get.side_effect = RedisError("Redis connection error") + + result = ToolProviderListCache.get_cached_providers("tenant_123") + + assert result is None + mock_redis_client.get.assert_called_once() + + def test_redis_fallback_no_default(self, mock_redis_client): + """Test redis_fallback decorator - no default return value (Redis error)""" + mock_redis_client.setex.side_effect = RedisError("Redis connection error") + + try: + ToolProviderListCache.set_cached_providers("tenant_123", "mcp", []) + except RedisError: + pytest.fail("set_cached_providers should not raise RedisError (handled by fallback)") + + mock_redis_client.setex.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index d9f6dcc43c..025a0d8d70 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -53,7 +53,7 @@ from sqlalchemy.exc import IntegrityError from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeConnectionError, @@ -99,10 +99,10 @@ class TestCacheEmbeddingDocuments: @pytest.fixture def sample_embedding_result(self): - """Create a sample TextEmbeddingResult for testing. + """Create a sample EmbeddingResult for testing. Returns: - TextEmbeddingResult: Mock embedding result with proper structure + EmbeddingResult: Mock embedding result with proper structure """ # Create normalized embedding vectors (dimension 1536 for ada-002) embedding_vector = np.random.randn(1536) @@ -118,7 +118,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_vector], usage=usage, @@ -197,7 +197,7 @@ class TestCacheEmbeddingDocuments: latency=0.8, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -296,7 +296,7 @@ class TestCacheEmbeddingDocuments: latency=0.6, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=new_embeddings, usage=usage, @@ -386,7 +386,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -449,7 +449,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[valid_vector.tolist(), nan_vector], usage=usage, @@ -629,7 +629,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -728,7 +728,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[nan_vector], usage=usage, @@ -793,7 +793,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -873,13 +873,13 @@ class TestEmbeddingModelSwitching: latency=0.3, ) - result_ada = TextEmbeddingResult( + result_ada = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_ada], usage=usage, ) - result_3_small = TextEmbeddingResult( + result_3_small = EmbeddingResult( model="text-embedding-3-small", embeddings=[normalized_3_small], usage=usage, @@ -953,13 +953,13 @@ class TestEmbeddingModelSwitching: latency=0.4, ) - result_openai = TextEmbeddingResult( + result_openai = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_openai], usage=usage_openai, ) - result_cohere = TextEmbeddingResult( + result_cohere = EmbeddingResult( model="embed-english-v3.0", embeddings=[normalized_cohere], usage=usage_cohere, @@ -1042,7 +1042,7 @@ class TestEmbeddingDimensionValidation: latency=0.7, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1095,7 +1095,7 @@ class TestEmbeddingDimensionValidation: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1148,7 +1148,7 @@ class TestEmbeddingDimensionValidation: latency=0.3, ) - result_ada = TextEmbeddingResult( + result_ada = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_ada], usage=usage_ada, @@ -1181,7 +1181,7 @@ class TestEmbeddingDimensionValidation: latency=0.4, ) - result_cohere = TextEmbeddingResult( + result_cohere = EmbeddingResult( model="embed-english-v3.0", embeddings=[normalized_cohere], usage=usage_cohere, @@ -1279,7 +1279,7 @@ class TestEmbeddingEdgeCases: latency=0.1, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1322,7 +1322,7 @@ class TestEmbeddingEdgeCases: latency=1.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1370,7 +1370,7 @@ class TestEmbeddingEdgeCases: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1422,7 +1422,7 @@ class TestEmbeddingEdgeCases: latency=0.2, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1478,7 +1478,7 @@ class TestEmbeddingEdgeCases: ) # Model returns embeddings for all texts - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1546,7 +1546,7 @@ class TestEmbeddingEdgeCases: latency=0.8, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1603,7 +1603,7 @@ class TestEmbeddingEdgeCases: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1657,7 +1657,7 @@ class TestEmbeddingEdgeCases: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1757,7 +1757,7 @@ class TestEmbeddingCachePerformance: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1826,7 +1826,7 @@ class TestEmbeddingCachePerformance: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1888,7 +1888,7 @@ class TestEmbeddingCachePerformance: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index d26e98db8d..c00fee8fe5 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -62,7 +62,7 @@ from core.indexing_runner import ( IndexingRunner, ) from core.model_runtime.entities.model_entities import ModelType -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import ChildDocument, Document from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule @@ -112,7 +112,7 @@ def create_mock_dataset_document( document_id: str | None = None, dataset_id: str | None = None, tenant_id: str | None = None, - doc_form: str = IndexType.PARAGRAPH_INDEX, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, data_source_type: str = "upload_file", doc_language: str = "English", ) -> Mock: @@ -133,8 +133,8 @@ def create_mock_dataset_document( Mock: A configured mock DatasetDocument object with all required attributes. Example: - >>> doc = create_mock_dataset_document(doc_form=IndexType.QA_INDEX) - >>> assert doc.doc_form == IndexType.QA_INDEX + >>> doc = create_mock_dataset_document(doc_form=IndexStructureType.QA_INDEX) + >>> assert doc.doc_form == IndexStructureType.QA_INDEX """ doc = Mock(spec=DatasetDocument) doc.id = document_id or str(uuid.uuid4()) @@ -276,7 +276,7 @@ class TestIndexingRunnerExtract: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.data_source_type = "upload_file" doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())} return doc @@ -616,7 +616,7 @@ class TestIndexingRunnerLoad: doc = Mock(spec=DatasetDocument) doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX return doc @pytest.fixture @@ -700,7 +700,7 @@ class TestIndexingRunnerLoad: """Test loading with parent-child index structure.""" # Arrange runner = IndexingRunner() - sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX + sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX sample_dataset.indexing_technique = "high_quality" # Add child documents @@ -775,7 +775,7 @@ class TestIndexingRunnerRun: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.doc_language = "English" doc.data_source_type = "upload_file" doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())} @@ -802,6 +802,21 @@ class TestIndexingRunnerRun: mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} mock_dependencies["db"].session.scalar.return_value = mock_process_rule + # Mock current_user (Account) for _transform + mock_current_user = MagicMock() + mock_current_user.set_tenant_id = MagicMock() + + # Setup db.session.query to return different results based on the model + def mock_query_side_effect(model): + mock_query_result = MagicMock() + if model.__name__ == "Dataset": + mock_query_result.filter_by.return_value.first.return_value = mock_dataset + elif model.__name__ == "Account": + mock_query_result.filter_by.return_value.first.return_value = mock_current_user + return mock_query_result + + mock_dependencies["db"].session.query.side_effect = mock_query_side_effect + # Mock processor mock_processor = MagicMock() mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor @@ -1268,7 +1283,7 @@ class TestIndexingRunnerLoadSegments: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.created_by = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX return doc @pytest.fixture @@ -1316,7 +1331,7 @@ class TestIndexingRunnerLoadSegments: """Test loading segments for parent-child index.""" # Arrange runner = IndexingRunner() - sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX + sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX # Add child documents for doc in sample_documents: @@ -1413,7 +1428,7 @@ class TestIndexingRunnerEstimate: tenant_id=tenant_id, extract_settings=extract_settings, tmp_processing_rule={"mode": "automatic", "rules": {}}, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 4912884c55..ebe6c37818 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -26,6 +26,18 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner +def create_mock_model_instance(): + """Create a properly configured mock ModelInstance for reranking tests.""" + mock_instance = Mock(spec=ModelInstance) + # Setup provider_model_bundle chain for check_model_support_vision + mock_instance.provider_model_bundle = Mock() + mock_instance.provider_model_bundle.configuration = Mock() + mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" + mock_instance.provider = "test-provider" + mock_instance.model = "test-model" + return mock_instance + + class TestRerankModelRunner: """Unit tests for RerankModelRunner. @@ -37,10 +49,23 @@ class TestRerankModelRunner: - Metadata preservation and score injection """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + @pytest.fixture def mock_model_instance(self): """Create a mock ModelInstance for reranking.""" mock_instance = Mock(spec=ModelInstance) + # Setup provider_model_bundle chain for check_model_support_vision + mock_instance.provider_model_bundle = Mock() + mock_instance.provider_model_bundle.configuration = Mock() + mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" + mock_instance.provider = "test-provider" + mock_instance.model = "test-model" return mock_instance @pytest.fixture @@ -803,7 +828,7 @@ class TestRerankRunnerFactory: - Parameters are forwarded to runner constructor """ # Arrange: Mock model instance - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() # Act: Create runner via factory runner = RerankRunnerFactory.create_rerank_runner( @@ -865,7 +890,7 @@ class TestRerankRunnerFactory: - String values are properly matched """ # Arrange: Mock model instance - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() # Act: Create runner using enum value runner = RerankRunnerFactory.create_rerank_runner( @@ -886,6 +911,13 @@ class TestRerankIntegration: - Real-world usage scenarios """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_model_reranking_full_workflow(self): """Test complete model-based reranking workflow. @@ -895,7 +927,7 @@ class TestRerankIntegration: - Top results are returned correctly """ # Arrange: Create mock model and documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -951,7 +983,7 @@ class TestRerankIntegration: - Normalization is consistent """ # Arrange: Create mock model with various scores - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -991,6 +1023,13 @@ class TestRerankEdgeCases: - Concurrent reranking scenarios """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_with_empty_metadata(self): """Test reranking when documents have empty metadata. @@ -1000,7 +1039,7 @@ class TestRerankEdgeCases: - Empty metadata documents are processed correctly """ # Arrange: Create documents with empty metadata - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1046,7 +1085,7 @@ class TestRerankEdgeCases: - Score comparison logic works at boundary """ # Arrange: Create mock with various scores including negatives - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1082,7 +1121,7 @@ class TestRerankEdgeCases: - No overflow or precision issues """ # Arrange: All documents with perfect scores - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1117,7 +1156,7 @@ class TestRerankEdgeCases: - Content encoding is preserved """ # Arrange: Documents with special characters - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1159,7 +1198,7 @@ class TestRerankEdgeCases: - Content is not truncated unexpectedly """ # Arrange: Documents with very long content - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() long_content = "This is a very long document. " * 1000 # ~30,000 characters mock_rerank_result = RerankResult( @@ -1196,7 +1235,7 @@ class TestRerankEdgeCases: - All documents are processed correctly """ # Arrange: Create 100 documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() num_docs = 100 # Create rerank results for all documents @@ -1287,7 +1326,7 @@ class TestRerankEdgeCases: - Documents can still be ranked """ # Arrange: Empty query - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1325,6 +1364,13 @@ class TestRerankPerformance: - Score calculation optimization """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_batch_processing(self): """Test that documents are processed in a single batch. @@ -1334,7 +1380,7 @@ class TestRerankPerformance: - Efficient batch processing """ # Arrange: Multiple documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)], @@ -1435,6 +1481,13 @@ class TestRerankErrorHandling: - Error propagation """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_model_invocation_error(self): """Test handling of model invocation errors. @@ -1444,7 +1497,7 @@ class TestRerankErrorHandling: - Error context is preserved """ # Arrange: Mock model that raises exception - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed") documents = [ @@ -1470,7 +1523,7 @@ class TestRerankErrorHandling: - Invalid results don't corrupt output """ # Arrange: Rerank result with invalid index - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 0163e42992..affd6c648f 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -425,15 +425,15 @@ class TestRetrievalService: # ==================== Vector Search Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents): + def test_vector_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test basic vector/semantic search functionality. This test validates the core vector search flow: 1. Dataset is retrieved from database - 2. embedding_search is called via ThreadPoolExecutor + 2. _retrieve is called via ThreadPoolExecutor 3. Documents are added to shared all_documents list 4. Results are returned to caller @@ -447,28 +447,28 @@ class TestRetrievalService: # Set up the mock dataset that will be "retrieved" from database mock_get_dataset.return_value = mock_dataset - # Create a side effect function that simulates embedding_search behavior - # In the real implementation, embedding_search: - # 1. Gets the dataset - # 2. Creates a Vector instance - # 3. Calls search_by_vector with embeddings - # 4. Extends all_documents with results - def side_effect_embedding_search( + # Create a side effect function that simulates _retrieve behavior + # _retrieve modifies the all_documents list in place + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - """Simulate embedding_search adding documents to the shared list.""" - all_documents.extend(sample_documents) + """Simulate _retrieve adding documents to the shared list.""" + if all_documents is not None: + all_documents.extend(sample_documents) - mock_embedding_search.side_effect = side_effect_embedding_search + mock_retrieve.side_effect = side_effect_retrieve # Define test parameters query = "What is Python?" # Natural language query @@ -481,7 +481,7 @@ class TestRetrievalService: # 1. Check if query is empty (early return if so) # 2. Get the dataset using _get_dataset # 3. Create ThreadPoolExecutor - # 4. Submit embedding_search task + # 4. Submit _retrieve task # 5. Wait for completion # 6. Return all_documents list results = RetrievalService.retrieve( @@ -502,15 +502,13 @@ class TestRetrievalService: # Verify documents maintain their scores (highest score first in sample_documents) assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents" - # Verify embedding_search was called exactly once + # Verify _retrieve was called exactly once # This confirms the search method was invoked by ThreadPoolExecutor - mock_embedding_search.assert_called_once() + mock_retrieve.assert_called_once() - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_document_filter( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_vector_search_with_document_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test vector search with document ID filtering. @@ -522,21 +520,25 @@ class TestRetrievalService: mock_get_dataset.return_value = mock_dataset filtered_docs = [sample_documents[0]] - def side_effect_embedding_search( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.extend(filtered_docs) + if all_documents is not None: + all_documents.extend(filtered_docs) - mock_embedding_search.side_effect = side_effect_embedding_search + mock_retrieve.side_effect = side_effect_retrieve document_ids_filter = [sample_documents[0].metadata["document_id"]] # Act @@ -552,12 +554,12 @@ class TestRetrievalService: assert len(results) == 1 assert results[0].metadata["doc_id"] == "doc1" # Verify document_ids_filter was passed - call_kwargs = mock_embedding_search.call_args.kwargs + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["document_ids_filter"] == document_ids_filter - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_vector_search_empty_results(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test vector search when no results match the query. @@ -567,8 +569,8 @@ class TestRetrievalService: """ # Arrange mock_get_dataset.return_value = mock_dataset - # embedding_search doesn't add anything to all_documents - mock_embedding_search.side_effect = lambda *args, **kwargs: None + # _retrieve doesn't add anything to all_documents + mock_retrieve.side_effect = lambda *args, **kwargs: None # Act results = RetrievalService.retrieve( @@ -583,9 +585,9 @@ class TestRetrievalService: # ==================== Keyword Search Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents): + def test_keyword_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test basic keyword search functionality. @@ -597,12 +599,25 @@ class TestRetrievalService: # Arrange mock_get_dataset.return_value = mock_dataset - def side_effect_keyword_search( - flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None + def side_effect_retrieve( + flask_app, + retrieval_method, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.extend(sample_documents) + if all_documents is not None: + all_documents.extend(sample_documents) - mock_keyword_search.side_effect = side_effect_keyword_search + mock_retrieve.side_effect = side_effect_retrieve query = "Python programming" top_k = 3 @@ -618,7 +633,7 @@ class TestRetrievalService: # Assert assert len(results) == 3 assert all(isinstance(doc, Document) for doc in results) - mock_keyword_search.assert_called_once() + mock_retrieve.assert_called_once() @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") @@ -1147,11 +1162,9 @@ class TestRetrievalService: # ==================== Metadata Filtering Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_metadata_filter( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_vector_search_with_metadata_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test vector search with metadata-based document filtering. @@ -1166,21 +1179,25 @@ class TestRetrievalService: filtered_doc = sample_documents[0] filtered_doc.metadata["category"] = "programming" - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.append(filtered_doc) + if all_documents is not None: + all_documents.append(filtered_doc) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve # Act results = RetrievalService.retrieve( @@ -1243,9 +1260,9 @@ class TestRetrievalService: # Assert assert results == [] - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test that exceptions during retrieval are properly handled. @@ -1256,22 +1273,26 @@ class TestRetrievalService: # Arrange mock_get_dataset.return_value = mock_dataset - # Make embedding_search add an exception to the exceptions list + # Make _retrieve add an exception to the exceptions list def side_effect_with_exception( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - exceptions.append("Search failed") + if exceptions is not None: + exceptions.append("Search failed") - mock_embedding_search.side_effect = side_effect_with_exception + mock_retrieve.side_effect = side_effect_with_exception # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1286,9 +1307,9 @@ class TestRetrievalService: # ==================== Score Threshold Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test vector search with score threshold filtering. @@ -1306,21 +1327,25 @@ class TestRetrievalService: provider="dify", ) - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.append(high_score_doc) + if all_documents is not None: + all_documents.append(high_score_doc) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve score_threshold = 0.8 @@ -1339,9 +1364,9 @@ class TestRetrievalService: # ==================== Top-K Limiting Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test that retrieval respects top_k parameter. @@ -1362,22 +1387,26 @@ class TestRetrievalService: for i in range(10) ] - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): # Return only top_k documents - all_documents.extend(many_docs[:top_k]) + if all_documents is not None: + all_documents.extend(many_docs[:top_k]) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve top_k = 3 @@ -1390,9 +1419,9 @@ class TestRetrievalService: ) # Assert - # Verify top_k was passed to embedding_search - assert mock_embedding_search.called - call_kwargs = mock_embedding_search.call_args.kwargs + # Verify _retrieve was called + assert mock_retrieve.called + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["top_k"] == top_k # Verify we got the right number of results assert len(results) == top_k @@ -1421,11 +1450,9 @@ class TestRetrievalService: # ==================== Reranking Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_semantic_search_with_reranking( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_semantic_search_with_reranking(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test semantic search with reranking model. @@ -1439,22 +1466,26 @@ class TestRetrievalService: # Simulate reranking changing order reranked_docs = list(reversed(sample_documents)) - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - # embedding_search handles reranking internally - all_documents.extend(reranked_docs) + # _retrieve handles reranking internally + if all_documents is not None: + all_documents.extend(reranked_docs) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve reranking_model = { "reranking_provider_name": "cohere", @@ -1473,7 +1504,7 @@ class TestRetrievalService: # Assert # For semantic search with reranking, reranking_model should be passed assert len(results) == 3 - call_kwargs = mock_embedding_search.call_args.kwargs + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["reranking_model"] == reranking_model diff --git a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py index 0bf4a3cf91..1361e16b06 100644 --- a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py +++ b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import pytest from core.tools.utils.web_reader_tool import ( @@ -103,7 +105,10 @@ def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.Monk monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get) - monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"}) + + mock_best = SimpleNamespace(encoding="utf-8") + mock_from_bytes = SimpleNamespace(best=lambda: mock_best) + monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes) # readability → a dict that maps to Article, then FULL_TEMPLATE def fake_simple_json_from_html_string(html, use_readability=True): @@ -134,7 +139,9 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest. monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get) - monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"}) + mock_best = SimpleNamespace(encoding="utf-8") + mock_from_bytes = SimpleNamespace(best=lambda: mock_best) + monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes) # readability returns empty plain_text monkeypatch.setattr(mod, "simple_json_from_html_string", lambda html, use_readability=True: {"plain_text": []}) @@ -162,7 +169,9 @@ def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) monkeypatch.setattr(mod.cloudscraper, "create_scraper", lambda: FakeScraper()) - monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"}) + mock_best = SimpleNamespace(encoding="utf-8") + mock_from_bytes = SimpleNamespace(best=lambda: mock_best) + monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes) monkeypatch.setattr( mod, "simple_json_from_html_string", @@ -234,7 +243,10 @@ def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.Mo monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get) - monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"}) + + mock_best = SimpleNamespace(encoding="utf-8") + mock_from_bytes = SimpleNamespace(best=lambda: mock_best) + monkeypatch.setattr(mod.charset_normalizer, "from_bytes", lambda _: mock_from_bytes) monkeypatch.setattr( mod, "simple_json_from_html_string", diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 02bf8e82f1..5d180c7cbc 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import pytest from core.app.entities.app_invoke_entities import InvokeFrom @@ -214,3 +216,76 @@ def test_create_variable_message(): assert message.message.variable_name == var_name assert message.message.variable_value == var_value assert message.message.stream is False + + +def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.MonkeyPatch): + """Ensure worker context can resolve EndUser when Account is missing.""" + + class StubSession: + def __init__(self, results: list): + self.results = results + + def scalar(self, _stmt): + return self.results.pop(0) + + tenant = SimpleNamespace(id="tenant_id") + end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id") + db_stub = SimpleNamespace(session=StubSession([tenant, None, end_user])) + + monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + + entity = ToolEntity( + identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), + parameters=[], + description=None, + has_runtime_parameters=False, + ) + runtime = ToolRuntime(tenant_id="tenant_id", invoke_from=InvokeFrom.SERVICE_API) + tool = WorkflowTool( + workflow_app_id="", + workflow_as_tool_id="", + version="1", + workflow_entities={}, + workflow_call_depth=1, + entity=entity, + runtime=runtime, + ) + + resolved_user = tool._resolve_user_from_database(user_id=end_user.id) + + assert resolved_user is end_user + + +def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pytest.MonkeyPatch): + """Return None if tenant cannot be found in worker context.""" + + class StubSession: + def __init__(self, results: list): + self.results = results + + def scalar(self, _stmt): + return self.results.pop(0) + + db_stub = SimpleNamespace(session=StubSession([None])) + monkeypatch.setattr("core.tools.workflow_as_tool.tool.db", db_stub) + + entity = ToolEntity( + identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), + parameters=[], + description=None, + has_runtime_parameters=False, + ) + runtime = ToolRuntime(tenant_id="missing_tenant", invoke_from=InvokeFrom.SERVICE_API) + tool = WorkflowTool( + workflow_app_id="", + workflow_as_tool_id="", + version="1", + workflow_entities={}, + workflow_call_depth=1, + entity=entity, + runtime=runtime, + ) + + resolved_user = tool._resolve_user_from_database(user_id="any") + + assert resolved_user is None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py new file mode 100644 index 0000000000..15eac6b537 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py @@ -0,0 +1,39 @@ +"""Tests for the EventManager.""" + +from __future__ import annotations + +import logging + +from core.workflow.graph_engine.event_management.event_manager import EventManager +from core.workflow.graph_engine.layers.base import GraphEngineLayer +from core.workflow.graph_events import GraphEngineEvent + + +class _FaultyLayer(GraphEngineLayer): + """Layer that raises from on_event to test error handling.""" + + def on_graph_start(self) -> None: # pragma: no cover - not used in tests + pass + + def on_event(self, event: GraphEngineEvent) -> None: + raise RuntimeError("boom") + + def on_graph_end(self, error: Exception | None) -> None: # pragma: no cover - not used in tests + pass + + +def test_event_manager_logs_layer_errors(caplog) -> None: + """Ensure errors raised by layers are logged when collecting events.""" + + event_manager = EventManager() + event_manager.set_layers([_FaultyLayer()]) + + with caplog.at_level(logging.ERROR): + event_manager.collect(GraphEngineEvent()) + + error_logs = [record for record in caplog.records if "Error in layer on_event" in record.getMessage()] + assert error_logs, "Expected layer errors to be logged" + + log_record = error_logs[0] + assert log_record.exc_info is not None + assert isinstance(log_record.exc_info[1], RuntimeError) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py new file mode 100644 index 0000000000..b1380cd6d2 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py @@ -0,0 +1,60 @@ +""" +Test case for end node without value_type field (backward compatibility). + +This test validates that end nodes work correctly even when the value_type +field is missing from the output configuration, ensuring backward compatibility +with older workflow definitions. +""" + +from core.workflow.graph_events import ( + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) + +from .test_table_runner import TableTestRunner, WorkflowTestCase + + +def test_end_node_without_value_type_field(): + """ + Test that end node works without explicit value_type field. + + The fixture implements a simple workflow that: + 1. Takes a query input from start node + 2. Passes it directly to end node + 3. End node outputs the value without specifying value_type + 4. Should correctly infer the type and output the value + + This ensures backward compatibility with workflow definitions + created before value_type became a required field. + """ + fixture_name = "end_node_without_value_type_field_workflow" + + case = WorkflowTestCase( + fixture_path=fixture_name, + inputs={"query": "test query"}, + expected_outputs={"query": "test query"}, + expected_event_sequence=[ + # Graph start + GraphRunStartedEvent, + # Start node + NodeRunStartedEvent, + NodeRunStreamChunkEvent, # Start node streams the input value + NodeRunSucceededEvent, + # End node + NodeRunStartedEvent, + NodeRunSucceededEvent, + # Graph end + GraphRunSucceededEvent, + ], + description="End node without value_type field should work correctly", + ) + + runner = TableTestRunner() + result = runner.run_test_case(case) + assert result.success, f"Test failed: {result.error}" + assert result.actual_outputs == {"query": "test query"}, ( + f"Expected output to be {{'query': 'test query'}}, got {result.actual_outputs}" + ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py b/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py index 98f344babf..b9bf4be13a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py @@ -7,9 +7,31 @@ This module tests the iteration node's ability to: """ from .test_database_utils import skip_if_database_unavailable +from .test_mock_config import MockConfigBuilder, NodeMockConfig from .test_table_runner import TableTestRunner, WorkflowTestCase +def _create_iteration_mock_config(): + """Helper to create a mock config for iteration tests.""" + + def code_inner_handler(node): + pool = node.graph_runtime_state.variable_pool + item_seg = pool.get(["iteration_node", "item"]) + if item_seg is not None: + item = item_seg.to_object() + return {"result": [item, item * 2]} + # This fallback is likely unreachable, but if it is, + # it doesn't simulate iteration with different values as the comment suggests. + return {"result": [1, 2]} + + return ( + MockConfigBuilder() + .with_node_output("code_node", {"result": [1, 2, 3]}) + .with_node_config(NodeMockConfig(node_id="code_inner_node", custom_handler=code_inner_handler)) + .build() + ) + + @skip_if_database_unavailable() def test_iteration_with_flatten_output_enabled(): """ @@ -27,7 +49,8 @@ def test_iteration_with_flatten_output_enabled(): inputs={}, expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, description="Iteration with flatten_output=True flattens nested arrays", - use_auto_mock=False, # Run code nodes directly + use_auto_mock=True, # Use auto-mock to avoid sandbox service + mock_config=_create_iteration_mock_config(), ) result = runner.run_test_case(test_case) @@ -56,7 +79,8 @@ def test_iteration_with_flatten_output_disabled(): inputs={}, expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, description="Iteration with flatten_output=False preserves nested structure", - use_auto_mock=False, # Run code nodes directly + use_auto_mock=True, # Use auto-mock to avoid sandbox service + mock_config=_create_iteration_mock_config(), ) result = runner.run_test_case(test_case) @@ -81,14 +105,16 @@ def test_iteration_flatten_output_comparison(): inputs={}, expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, description="flatten_output=True: Flattened output", - use_auto_mock=False, # Run code nodes directly + use_auto_mock=True, # Use auto-mock to avoid sandbox service + mock_config=_create_iteration_mock_config(), ), WorkflowTestCase( fixture_path="iteration_flatten_output_disabled_workflow", inputs={}, expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, description="flatten_output=False: Nested output", - use_auto_mock=False, # Run code nodes directly + use_auto_mock=True, # Use auto-mock to avoid sandbox service + mock_config=_create_iteration_mock_config(), ), ] diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py index 0f6b7e4ab6..47a5df92a4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py @@ -1,3 +1,4 @@ +import json from unittest.mock import Mock, PropertyMock, patch import httpx @@ -138,3 +139,95 @@ def test_is_file_with_no_content_disposition(mock_response): type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) response = Response(mock_response) assert response.is_file + + +# UTF-8 Encoding Tests +@pytest.mark.parametrize( + ("content_bytes", "expected_text", "description"), + [ + # Chinese UTF-8 bytes + ( + b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}', + '{"message": "你好世界"}', + "Chinese characters UTF-8", + ), + # Japanese UTF-8 bytes + ( + b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}', + '{"message": "こんにちは"}', + "Japanese characters UTF-8", + ), + # Korean UTF-8 bytes + ( + b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}', + '{"message": "안녕하세요"}', + "Korean characters UTF-8", + ), + # Arabic UTF-8 + (b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"), + # European characters UTF-8 + (b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"), + # Simple ASCII + (b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"), + ], +) +def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description): + """Test that Response.text properly decodes UTF-8 content with charset_normalizer""" + mock_response.headers = {"content-type": "application/json; charset=utf-8"} + type(mock_response).content = PropertyMock(return_value=content_bytes) + # Mock httpx response.text to return something different (simulating potential encoding issues) + mock_response.text = "incorrect-fallback-text" # To ensure we are not falling back to httpx's text property + + response = Response(mock_response) + + # Our enhanced text property should decode properly using charset_normalizer + assert response.text == expected_text, ( + f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}" + ) + + +def test_text_property_fallback_to_httpx(mock_response): + """Test that Response.text falls back to httpx.text when charset_normalizer fails""" + mock_response.headers = {"content-type": "application/json"} + + # Create malformed UTF-8 bytes + malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}' + type(mock_response).content = PropertyMock(return_value=malformed_bytes) + + # Mock httpx.text to return some fallback value + fallback_text = '{"text": "fallback"}' + mock_response.text = fallback_text + + response = Response(mock_response) + + # Should fall back to httpx's text when charset_normalizer fails + assert response.text == fallback_text + + +@pytest.mark.parametrize( + ("json_content", "description"), + [ + # JSON with escaped Unicode (like Flask jsonify()) + ('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"), + # JSON with mixed escape sequences and UTF-8 + ('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"), + # JSON with complex escape sequences + ('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"), + ], +) +def test_text_property_with_escaped_unicode(mock_response, json_content, description): + """Test Response.text with JSON containing Unicode escape sequences""" + mock_response.headers = {"content-type": "application/json"} + + content_bytes = json_content.encode("utf-8") + type(mock_response).content = PropertyMock(return_value=content_bytes) + mock_response.text = json_content # httpx would return the same for valid UTF-8 + + response = Response(mock_response) + + # Should preserve the escape sequences (valid JSON) + assert response.text == json_content, f"Failed for {description}" + + # The text should be valid JSON that can be parsed back to proper Unicode + parsed = json.loads(response.text) + assert isinstance(parsed, dict), f"Invalid JSON for {description}" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py new file mode 100644 index 0000000000..83799c9508 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -0,0 +1,227 @@ +import time + +import pytest +from pydantic import ValidationError as PydanticValidationError + +from core.app.app_config.entities import VariableEntity, VariableEntityType +from core.workflow.entities import GraphInitParams +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + + +def make_start_node(user_inputs, variables): + variable_pool = VariablePool( + system_variables=SystemVariable(), + user_inputs=user_inputs, + conversation_variables=[], + ) + + config = { + "id": "start", + "data": StartNodeData(title="Start", variables=variables).model_dump(), + } + + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + ) + + return StartNode( + id="start", + config=config, + graph_init_params=GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="wf", + graph_config={}, + user_id="u", + user_from="account", + invoke_from="debugger", + call_depth=0, + ), + graph_runtime_state=graph_runtime_state, + ) + + +def test_json_object_valid_schema(): + schema = { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age"], + } + + variables = [ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + json_schema=schema, + ) + ] + + user_inputs = {"profile": {"age": 20, "name": "Tom"}} + + node = make_start_node(user_inputs, variables) + result = node._run() + + assert result.outputs["profile"] == {"age": 20, "name": "Tom"} + + +def test_json_object_invalid_json_string(): + variables = [ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ] + + # Missing closing brace makes this invalid JSON + user_inputs = {"profile": '{"age": 20, "name": "Tom"'} + + node = make_start_node(user_inputs, variables) + + with pytest.raises(ValueError, match="profile must be a JSON object"): + node._run() + + +@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"]) +def test_json_object_valid_json_but_not_object(value): + variables = [ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ] + + user_inputs = {"profile": value} + + node = make_start_node(user_inputs, variables) + + with pytest.raises(ValueError, match="profile must be a JSON object"): + node._run() + + +def test_json_object_does_not_match_schema(): + schema = { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + + variables = [ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + json_schema=schema, + ) + ] + + # age is a string, which violates the schema (expects number) + user_inputs = {"profile": {"age": "twenty", "name": "Tom"}} + + node = make_start_node(user_inputs, variables) + + with pytest.raises(ValueError, match=r"JSON object for 'profile' does not match schema:"): + node._run() + + +def test_json_object_missing_required_schema_field(): + schema = { + "type": "object", + "properties": { + "age": {"type": "number"}, + "name": {"type": "string"}, + }, + "required": ["age", "name"], + } + + variables = [ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + json_schema=schema, + ) + ] + + # Missing required field "name" + user_inputs = {"profile": {"age": 20}} + + node = make_start_node(user_inputs, variables) + + with pytest.raises( + ValueError, match=r"JSON object for 'profile' does not match schema: 'name' is a required property" + ): + node._run() + + +def test_json_object_required_variable_missing_from_inputs(): + variables = [ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ] + + user_inputs = {} + + node = make_start_node(user_inputs, variables) + + with pytest.raises(ValueError, match="profile is required in input form"): + node._run() + + +def test_json_object_invalid_json_schema_string(): + variable = VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + + # Bypass pydantic type validation on assignment to simulate an invalid JSON schema string + variable.json_schema = "{invalid-json-schema" + + variables = [variable] + user_inputs = {"profile": '{"age": 20}'} + + # Invalid json_schema string should be rejected during node data hydration + with pytest.raises(PydanticValidationError): + make_start_node(user_inputs, variables) + + +def test_json_object_optional_variable_not_provided(): + variables = [ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=False, + ) + ] + + user_inputs = {} + + node = make_start_node(user_inputs, variables) + + # Current implementation raises a validation error even when the variable is optional + with pytest.raises(ValueError, match="profile must be a JSON object"): + node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index ef23a8f565..c62fc4d8fe 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -30,7 +30,13 @@ def test_overwrite_string_variable(): "nodes": [ {"data": {"type": "start", "title": "Start"}, "id": "start"}, { - "data": {"type": "assigner", "version": "1", "title": "Variable Assigner", "items": []}, + "data": { + "type": "assigner", + "title": "Variable Assigner", + "assigned_variable_selector": ["conversation", "test_conversation_variable"], + "write_mode": "over-write", + "input_variable_selector": ["node_id", "test_string_variable"], + }, "id": "assigner", }, ], @@ -131,7 +137,13 @@ def test_append_variable_to_array(): "nodes": [ {"data": {"type": "start", "title": "Start"}, "id": "start"}, { - "data": {"type": "assigner", "version": "1", "title": "Variable Assigner", "items": []}, + "data": { + "type": "assigner", + "title": "Variable Assigner", + "assigned_variable_selector": ["conversation", "test_conversation_variable"], + "write_mode": "append", + "input_variable_selector": ["node_id", "test_string_variable"], + }, "id": "assigner", }, ], @@ -231,7 +243,13 @@ def test_clear_array(): "nodes": [ {"data": {"type": "start", "title": "Start"}, "id": "start"}, { - "data": {"type": "assigner", "version": "1", "title": "Variable Assigner", "items": []}, + "data": { + "type": "assigner", + "title": "Variable Assigner", + "assigned_variable_selector": ["conversation", "test_conversation_variable"], + "write_mode": "clear", + "input_variable_selector": [], + }, "id": "assigner", }, ], diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index f793341e73..caa36734ad 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -78,7 +78,7 @@ def test_remove_first_from_array(): "nodes": [ {"data": {"type": "start", "title": "Start"}, "id": "start"}, { - "data": {"type": "assigner", "title": "Variable Assigner", "items": []}, + "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, "id": "assigner", }, ], @@ -162,7 +162,7 @@ def test_remove_last_from_array(): "nodes": [ {"data": {"type": "start", "title": "Start"}, "id": "start"}, { - "data": {"type": "assigner", "title": "Variable Assigner", "items": []}, + "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, "id": "assigner", }, ], @@ -243,7 +243,7 @@ def test_remove_first_from_empty_array(): "nodes": [ {"data": {"type": "start", "title": "Start"}, "id": "start"}, { - "data": {"type": "assigner", "title": "Variable Assigner", "items": []}, + "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, "id": "assigner", }, ], @@ -324,7 +324,7 @@ def test_remove_last_from_empty_array(): "nodes": [ {"data": {"type": "start", "title": "Start"}, "id": "start"}, { - "data": {"type": "assigner", "title": "Variable Assigner", "items": []}, + "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, "id": "assigner", }, ], diff --git a/api/tests/unit_tests/extensions/otel/__init__.py b/api/tests/unit_tests/extensions/otel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/extensions/otel/conftest.py b/api/tests/unit_tests/extensions/otel/conftest.py new file mode 100644 index 0000000000..b7f27c4da8 --- /dev/null +++ b/api/tests/unit_tests/extensions/otel/conftest.py @@ -0,0 +1,96 @@ +""" +Shared fixtures for OTel tests. + +Provides: +- Mock TracerProvider with MemorySpanExporter +- Mock configurations +- Test data factories +""" + +from unittest.mock import MagicMock, create_autospec + +import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import set_tracer_provider + + +@pytest.fixture +def memory_span_exporter(): + """Provide an in-memory span exporter for testing.""" + return InMemorySpanExporter() + + +@pytest.fixture +def tracer_provider_with_memory_exporter(memory_span_exporter): + """Provide a TracerProvider configured with memory exporter.""" + import opentelemetry.trace as trace_api + + trace_api._TRACER_PROVIDER = None + trace_api._TRACER_PROVIDER_SET_ONCE._done = False + + provider = TracerProvider() + processor = SimpleSpanProcessor(memory_span_exporter) + provider.add_span_processor(processor) + set_tracer_provider(provider) + + yield provider + + provider.force_flush() + + +@pytest.fixture +def mock_app_model(): + """Create a mock App model.""" + app = MagicMock() + app.id = "test-app-id" + app.tenant_id = "test-tenant-id" + return app + + +@pytest.fixture +def mock_account_user(): + """Create a mock Account user.""" + from models.model import Account + + user = create_autospec(Account, instance=True) + user.id = "test-user-id" + return user + + +@pytest.fixture +def mock_end_user(): + """Create a mock EndUser.""" + from models.model import EndUser + + user = create_autospec(EndUser, instance=True) + user.id = "test-end-user-id" + return user + + +@pytest.fixture +def mock_workflow_runner(): + """Create a mock WorkflowAppRunner.""" + runner = MagicMock() + runner.application_generate_entity = MagicMock() + runner.application_generate_entity.user_id = "test-user-id" + runner.application_generate_entity.stream = True + runner.application_generate_entity.app_config = MagicMock() + runner.application_generate_entity.app_config.app_id = "test-app-id" + runner.application_generate_entity.app_config.tenant_id = "test-tenant-id" + runner.application_generate_entity.app_config.workflow_id = "test-workflow-id" + return runner + + +@pytest.fixture(autouse=True) +def reset_handler_instances(): + """Reset handler singleton instances before each test.""" + from extensions.otel.decorators.base import _HANDLER_INSTANCES + + _HANDLER_INSTANCES.clear() + from extensions.otel.decorators.handler import SpanHandler + + _HANDLER_INSTANCES[SpanHandler] = SpanHandler() + yield + _HANDLER_INSTANCES.clear() diff --git a/api/tests/unit_tests/extensions/otel/decorators/__init__.py b/api/tests/unit_tests/extensions/otel/decorators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/extensions/otel/decorators/handlers/__init__.py b/api/tests/unit_tests/extensions/otel/decorators/handlers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py new file mode 100644 index 0000000000..f7475f2239 --- /dev/null +++ b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py @@ -0,0 +1,92 @@ +""" +Tests for AppGenerateHandler. + +Test objectives: +1. Verify handler compatibility with real function signature (fails when parameters change) +2. Verify span attribute mapping correctness +""" + +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.otel.decorators.handlers.generate_handler import AppGenerateHandler +from extensions.otel.semconv import DifySpanAttributes, GenAIAttributes + + +class TestAppGenerateHandler: + """Core tests for AppGenerateHandler""" + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_compatible_with_real_function_signature( + self, tracer_provider_with_memory_exporter, mock_app_model, mock_account_user + ): + """ + Verify handler compatibility with real AppGenerateService.generate signature. + + If AppGenerateService.generate parameters change, this test will fail, + prompting developers to update the handler's parameter extraction logic. + """ + from services.app_generate_service import AppGenerateService + + handler = AppGenerateHandler() + + kwargs = { + "app_model": mock_app_model, + "user": mock_account_user, + "args": {"workflow_id": "test-wf-123"}, + "invoke_from": InvokeFrom.DEBUGGER, + "streaming": True, + "root_node_id": None, + } + + arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs) + + assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate" + assert "app_model" in arguments, "Handler uses app_model but parameter is missing" + assert "user" in arguments, "Handler uses user but parameter is missing" + assert "args" in arguments, "Handler uses args but parameter is missing" + assert "streaming" in arguments, "Handler uses streaming but parameter is missing" + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_all_span_attributes_set_correctly( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_app_model, mock_account_user + ): + """Verify all span attributes are mapped correctly""" + handler = AppGenerateHandler() + tracer = tracer_provider_with_memory_exporter.get_tracer(__name__) + + test_app_id = "app-456" + test_tenant_id = "tenant-789" + test_user_id = "user-111" + test_workflow_id = "wf-222" + + mock_app_model.id = test_app_id + mock_app_model.tenant_id = test_tenant_id + mock_account_user.id = test_user_id + + def dummy_func(app_model, user, args, invoke_from, streaming=True): + return "result" + + handler.wrapper( + tracer, + dummy_func, + (), + { + "app_model": mock_app_model, + "user": mock_account_user, + "args": {"workflow_id": test_workflow_id}, + "invoke_from": InvokeFrom.DEBUGGER, + "streaming": False, + }, + ) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + + assert attrs[DifySpanAttributes.APP_ID] == test_app_id + assert attrs[DifySpanAttributes.TENANT_ID] == test_tenant_id + assert attrs[GenAIAttributes.USER_ID] == test_user_id + assert attrs[DifySpanAttributes.WORKFLOW_ID] == test_workflow_id + assert attrs[DifySpanAttributes.USER_TYPE] == "Account" + assert attrs[DifySpanAttributes.STREAMING] is False diff --git a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py new file mode 100644 index 0000000000..500f80fc3c --- /dev/null +++ b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py @@ -0,0 +1,76 @@ +""" +Tests for WorkflowAppRunnerHandler. + +Test objectives: +1. Verify handler compatibility with real WorkflowAppRunner structure (fails when structure changes) +2. Verify span attribute mapping correctness +""" + +from unittest.mock import patch + +from extensions.otel.decorators.handlers.workflow_app_runner_handler import WorkflowAppRunnerHandler +from extensions.otel.semconv import DifySpanAttributes, GenAIAttributes + + +class TestWorkflowAppRunnerHandler: + """Core tests for WorkflowAppRunnerHandler""" + + def test_handler_structure_dependencies(self): + """ + Verify handler dependencies on WorkflowAppRunner structure. + + Handler depends on: + - runner.application_generate_entity (WorkflowAppGenerateEntity) + - entity.app_config (WorkflowAppConfig) + - entity.user_id, entity.stream + - app_config.app_id, app_config.tenant_id, app_config.workflow_id + + If these attribute paths change in real types, this test will fail, + prompting developers to update the handler's attribute access logic. + """ + from core.app.app_config.entities import WorkflowUIBasedAppConfig + from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity + + required_entity_fields = ["user_id", "stream", "app_config"] + entity_fields = WorkflowAppGenerateEntity.model_fields + for field in required_entity_fields: + assert field in entity_fields, f"Handler expects WorkflowAppGenerateEntity.{field} but field is missing" + + required_config_fields = ["app_id", "tenant_id", "workflow_id"] + config_fields = WorkflowUIBasedAppConfig.model_fields + for field in required_config_fields: + assert field in config_fields, f"Handler expects app_config.{field} but field is missing" + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_all_span_attributes_set_correctly( + self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_workflow_runner + ): + """Verify all span attributes are mapped correctly""" + handler = WorkflowAppRunnerHandler() + tracer = tracer_provider_with_memory_exporter.get_tracer(__name__) + + test_app_id = "app-999" + test_tenant_id = "tenant-888" + test_user_id = "user-777" + test_workflow_id = "wf-666" + + mock_workflow_runner.application_generate_entity.user_id = test_user_id + mock_workflow_runner.application_generate_entity.stream = False + mock_workflow_runner.application_generate_entity.app_config.app_id = test_app_id + mock_workflow_runner.application_generate_entity.app_config.tenant_id = test_tenant_id + mock_workflow_runner.application_generate_entity.app_config.workflow_id = test_workflow_id + + def runner_run(self): + return "result" + + handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {}) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + attrs = spans[0].attributes + + assert attrs[DifySpanAttributes.APP_ID] == test_app_id + assert attrs[DifySpanAttributes.TENANT_ID] == test_tenant_id + assert attrs[GenAIAttributes.USER_ID] == test_user_id + assert attrs[DifySpanAttributes.WORKFLOW_ID] == test_workflow_id + assert attrs[DifySpanAttributes.STREAMING] is False diff --git a/api/tests/unit_tests/extensions/otel/decorators/test_base.py b/api/tests/unit_tests/extensions/otel/decorators/test_base.py new file mode 100644 index 0000000000..a42f861bb7 --- /dev/null +++ b/api/tests/unit_tests/extensions/otel/decorators/test_base.py @@ -0,0 +1,119 @@ +""" +Tests for trace_span decorator. + +Test coverage: +- Decorator basic functionality +- Enable/disable logic +- Handler singleton management +- Integration with OpenTelemetry SDK +""" + +from unittest.mock import patch + +import pytest +from opentelemetry.trace import StatusCode + +from extensions.otel.decorators.base import trace_span + + +class TestTraceSpanDecorator: + """Test trace_span decorator basic functionality.""" + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_decorated_function_executes_normally(self, tracer_provider_with_memory_exporter): + """Test that decorated function executes and returns correct value.""" + + @trace_span() + def test_func(x, y): + return x + y + + result = test_func(2, 3) + assert result == 5 + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_decorator_with_args_and_kwargs(self, tracer_provider_with_memory_exporter): + """Test that decorator correctly handles args and kwargs.""" + + @trace_span() + def test_func(a, b, c=10): + return a + b + c + + result = test_func(1, 2, c=3) + assert result == 6 + + +class TestTraceSpanWithMemoryExporter: + """Test trace_span with MemorySpanExporter to verify span creation.""" + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_span_is_created_and_exported(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test that span is created and exported to memory exporter.""" + + @trace_span() + def test_func(): + return "result" + + test_func() + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_span_name_matches_function(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test that span name matches the decorated function.""" + + @trace_span() + def my_test_function(): + return "result" + + my_test_function() + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert "my_test_function" in spans[0].name + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_span_status_is_ok_on_success(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test that span status is OK when function succeeds.""" + + @trace_span() + def test_func(): + return "result" + + test_func() + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.OK + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_span_status_is_error_on_exception(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test that span status is ERROR when function raises exception.""" + + @trace_span() + def test_func(): + raise ValueError("test error") + + with pytest.raises(ValueError, match="test error"): + test_func() + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_exception_is_recorded_in_span(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test that exception details are recorded in span events.""" + + @trace_span() + def test_func(): + raise ValueError("test error") + + with pytest.raises(ValueError): + test_func() + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + events = spans[0].events + assert len(events) > 0 + assert any("exception" in event.name.lower() for event in events) diff --git a/api/tests/unit_tests/extensions/otel/decorators/test_handler.py b/api/tests/unit_tests/extensions/otel/decorators/test_handler.py new file mode 100644 index 0000000000..44788bab9a --- /dev/null +++ b/api/tests/unit_tests/extensions/otel/decorators/test_handler.py @@ -0,0 +1,258 @@ +""" +Tests for SpanHandler base class. + +Test coverage: +- _build_span_name method +- _extract_arguments method +- wrapper method default implementation +- Signature caching +""" + +from unittest.mock import patch + +import pytest +from opentelemetry.trace import StatusCode + +from extensions.otel.decorators.handler import SpanHandler + + +class TestSpanHandlerExtractArguments: + """Test SpanHandler._extract_arguments method.""" + + def test_extract_positional_arguments(self): + """Test extracting positional arguments.""" + handler = SpanHandler() + + def func(a, b, c): + pass + + args = (1, 2, 3) + kwargs = {} + result = handler._extract_arguments(func, args, kwargs) + + assert result is not None + assert result["a"] == 1 + assert result["b"] == 2 + assert result["c"] == 3 + + def test_extract_keyword_arguments(self): + """Test extracting keyword arguments.""" + handler = SpanHandler() + + def func(a, b, c): + pass + + args = () + kwargs = {"a": 1, "b": 2, "c": 3} + result = handler._extract_arguments(func, args, kwargs) + + assert result is not None + assert result["a"] == 1 + assert result["b"] == 2 + assert result["c"] == 3 + + def test_extract_mixed_arguments(self): + """Test extracting mixed positional and keyword arguments.""" + handler = SpanHandler() + + def func(a, b, c): + pass + + args = (1,) + kwargs = {"b": 2, "c": 3} + result = handler._extract_arguments(func, args, kwargs) + + assert result is not None + assert result["a"] == 1 + assert result["b"] == 2 + assert result["c"] == 3 + + def test_extract_arguments_with_defaults(self): + """Test extracting arguments with default values.""" + handler = SpanHandler() + + def func(a, b=10, c=20): + pass + + args = (1,) + kwargs = {} + result = handler._extract_arguments(func, args, kwargs) + + assert result is not None + assert result["a"] == 1 + assert result["b"] == 10 + assert result["c"] == 20 + + def test_extract_arguments_handles_self(self): + """Test extracting arguments from instance method (with self).""" + handler = SpanHandler() + + class MyClass: + def method(self, a, b): + pass + + instance = MyClass() + args = (1, 2) + kwargs = {} + result = handler._extract_arguments(instance.method, args, kwargs) + + assert result is not None + assert result["a"] == 1 + assert result["b"] == 2 + + def test_extract_arguments_returns_none_on_error(self): + """Test that _extract_arguments returns None when extraction fails.""" + handler = SpanHandler() + + def func(a, b): + pass + + args = (1,) + kwargs = {} + result = handler._extract_arguments(func, args, kwargs) + + assert result is None + + def test_signature_caching(self): + """Test that function signatures are cached.""" + handler = SpanHandler() + + def func(a, b): + pass + + assert func not in handler._signature_cache + + handler._extract_arguments(func, (1, 2), {}) + assert func in handler._signature_cache + + cached_sig = handler._signature_cache[func] + handler._extract_arguments(func, (3, 4), {}) + assert handler._signature_cache[func] is cached_sig + + +class TestSpanHandlerWrapper: + """Test SpanHandler.wrapper default implementation.""" + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_wrapper_creates_span(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test that wrapper creates a span.""" + handler = SpanHandler() + tracer = tracer_provider_with_memory_exporter.get_tracer(__name__) + + def test_func(): + return "result" + + result = handler.wrapper(tracer, test_func, (), {}) + + assert result == "result" + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_wrapper_sets_span_kind_internal(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test that wrapper sets SpanKind to INTERNAL.""" + from opentelemetry.trace import SpanKind + + handler = SpanHandler() + tracer = tracer_provider_with_memory_exporter.get_tracer(__name__) + + def test_func(): + return "result" + + handler.wrapper(tracer, test_func, (), {}) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].kind == SpanKind.INTERNAL + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_wrapper_sets_status_ok_on_success(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test that wrapper sets status to OK when function succeeds.""" + handler = SpanHandler() + tracer = tracer_provider_with_memory_exporter.get_tracer(__name__) + + def test_func(): + return "result" + + handler.wrapper(tracer, test_func, (), {}) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.OK + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_wrapper_records_exception_on_error(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test that wrapper records exception when function raises.""" + handler = SpanHandler() + tracer = tracer_provider_with_memory_exporter.get_tracer(__name__) + + def test_func(): + raise ValueError("test error") + + with pytest.raises(ValueError, match="test error"): + handler.wrapper(tracer, test_func, (), {}) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + events = spans[0].events + assert len(events) > 0 + assert any("exception" in event.name.lower() for event in events) + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_wrapper_sets_status_error_on_exception(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test that wrapper sets status to ERROR when function raises exception.""" + handler = SpanHandler() + tracer = tracer_provider_with_memory_exporter.get_tracer(__name__) + + def test_func(): + raise ValueError("test error") + + with pytest.raises(ValueError): + handler.wrapper(tracer, test_func, (), {}) + + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].status.status_code == StatusCode.ERROR + assert "test error" in spans[0].status.description + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_wrapper_re_raises_exception(self, tracer_provider_with_memory_exporter): + """Test that wrapper re-raises exception after recording it.""" + handler = SpanHandler() + tracer = tracer_provider_with_memory_exporter.get_tracer(__name__) + + def test_func(): + raise ValueError("test error") + + with pytest.raises(ValueError, match="test error"): + handler.wrapper(tracer, test_func, (), {}) + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test that wrapper correctly passes arguments to wrapped function.""" + handler = SpanHandler() + tracer = tracer_provider_with_memory_exporter.get_tracer(__name__) + + def test_func(a, b, c=10): + return a + b + c + + result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3}) + + assert result == 6 + + @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) + def test_wrapper_with_memory_exporter(self, tracer_provider_with_memory_exporter, memory_span_exporter): + """Test wrapper end-to-end with memory exporter.""" + handler = SpanHandler() + tracer = tracer_provider_with_memory_exporter.get_tracer(__name__) + + def my_function(x): + return x * 2 + + result = handler.wrapper(tracer, my_function, (5,), {}) + + assert result == 10 + spans = memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + assert "my_function" in spans[0].name + assert spans[0].status.status_code == StatusCode.OK diff --git a/api/tests/unit_tests/extensions/test_ext_request_logging.py b/api/tests/unit_tests/extensions/test_ext_request_logging.py index cf6e172e4d..dcb457c806 100644 --- a/api/tests/unit_tests/extensions/test_ext_request_logging.py +++ b/api/tests/unit_tests/extensions/test_ext_request_logging.py @@ -263,3 +263,62 @@ class TestResponseUnmodified: ) assert response.text == _RESPONSE_NEEDLE assert response.status_code == 200 + + +class TestRequestFinishedInfoAccessLine: + def test_info_access_log_includes_method_path_status_duration_trace_id(self, monkeypatch, caplog): + """Ensure INFO access line contains expected fields with computed duration and trace id.""" + app = _get_test_app() + # Push a real request context so flask.request and g are available + with app.test_request_context("/foo", method="GET"): + # Seed start timestamp via the extension's own start hook and control perf_counter deterministically + seq = iter([100.0, 100.123456]) + monkeypatch.setattr(ext_request_logging.time, "perf_counter", lambda: next(seq)) + # Provide a deterministic trace id + monkeypatch.setattr( + ext_request_logging, + "get_trace_id_from_otel_context", + lambda: "trace-xyz", + ) + # Simulate request_started to record start timestamp on g + ext_request_logging._log_request_started(app) + + # Capture logs from the real logger at INFO level only (skip DEBUG branch) + caplog.set_level(logging.INFO, logger=ext_request_logging.__name__) + response = Response(json.dumps({"ok": True}), mimetype="application/json", status=200) + _log_request_finished(app, response) + + # Verify a single INFO record with the five fields in order + info_records = [rec for rec in caplog.records if rec.levelno == logging.INFO] + assert len(info_records) == 1 + msg = info_records[0].getMessage() + # Expected format: METHOD PATH STATUS DURATION_MS TRACE_ID + assert "GET" in msg + assert "/foo" in msg + assert "200" in msg + assert "123.456" in msg # rounded to 3 decimals + assert "trace-xyz" in msg + + def test_info_access_log_uses_dash_without_start_timestamp(self, monkeypatch, caplog): + app = _get_test_app() + with app.test_request_context("/bar", method="POST"): + # No g.__request_started_ts set -> duration should be '-' + monkeypatch.setattr( + ext_request_logging, + "get_trace_id_from_otel_context", + lambda: "tid-no-start", + ) + caplog.set_level(logging.INFO, logger=ext_request_logging.__name__) + response = Response("OK", mimetype="text/plain", status=204) + _log_request_finished(app, response) + + info_records = [rec for rec in caplog.records if rec.levelno == logging.INFO] + assert len(info_records) == 1 + msg = info_records[0].getMessage() + assert "POST" in msg + assert "/bar" in msg + assert "204" in msg + # Duration placeholder + # The fields are space separated; ensure a standalone '-' appears + assert " - " in msg or msg.endswith(" -") + assert "tid-no-start" in msg diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 268ba1282a..e35788660d 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -1149,3 +1149,258 @@ class TestModelIntegration: # Assert assert site.app_id == app.id assert app.enable_site is True + + +class TestConversationStatusCount: + """Test suite for Conversation.status_count property N+1 query fix.""" + + def test_status_count_no_messages(self): + """Test status_count returns None when conversation has no messages.""" + # Arrange + conversation = Conversation( + app_id=str(uuid4()), + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = str(uuid4()) + + # Mock the database query to return no messages + with patch("models.model.db.session.scalars") as mock_scalars: + mock_scalars.return_value.all.return_value = [] + + # Act + result = conversation.status_count + + # Assert + assert result is None + + def test_status_count_messages_without_workflow_runs(self): + """Test status_count when messages have no workflow_run_id.""" + # Arrange + app_id = str(uuid4()) + conversation_id = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + # Mock the database query to return no messages with workflow_run_id + with patch("models.model.db.session.scalars") as mock_scalars: + mock_scalars.return_value.all.return_value = [] + + # Act + result = conversation.status_count + + # Assert + assert result is None + + def test_status_count_batch_loading_implementation(self): + """Test that status_count uses batch loading instead of N+1 queries.""" + # Arrange + from core.workflow.enums import WorkflowExecutionStatus + + app_id = str(uuid4()) + conversation_id = str(uuid4()) + + # Create workflow run IDs + workflow_run_id_1 = str(uuid4()) + workflow_run_id_2 = str(uuid4()) + workflow_run_id_3 = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + # Mock messages with workflow_run_id + mock_messages = [ + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id_1, + ), + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id_2, + ), + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id_3, + ), + ] + + # Mock workflow runs with different statuses + mock_workflow_runs = [ + MagicMock( + id=workflow_run_id_1, + status=WorkflowExecutionStatus.SUCCEEDED.value, + app_id=app_id, + ), + MagicMock( + id=workflow_run_id_2, + status=WorkflowExecutionStatus.FAILED.value, + app_id=app_id, + ), + MagicMock( + id=workflow_run_id_3, + status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, + app_id=app_id, + ), + ] + + # Track database calls + calls_made = [] + + def mock_scalars(query): + calls_made.append(str(query)) + mock_result = MagicMock() + + # Return messages for the first query (messages with workflow_run_id) + if "messages" in str(query) and "conversation_id" in str(query): + mock_result.all.return_value = mock_messages + # Return workflow runs for the batch query + elif "workflow_runs" in str(query): + mock_result.all.return_value = mock_workflow_runs + else: + mock_result.all.return_value = [] + + return mock_result + + # Act & Assert + with patch("models.model.db.session.scalars", side_effect=mock_scalars): + result = conversation.status_count + + # Verify only 2 database queries were made (not N+1) + assert len(calls_made) == 2, f"Expected 2 queries, got {len(calls_made)}: {calls_made}" + + # Verify the first query gets messages + assert "messages" in calls_made[0] + assert "conversation_id" in calls_made[0] + + # Verify the second query batch loads workflow runs with proper filtering + assert "workflow_runs" in calls_made[1] + assert "app_id" in calls_made[1] # Security filter applied + assert "IN" in calls_made[1] # Batch loading with IN clause + + # Verify correct status counts + assert result["success"] == 1 # One SUCCEEDED + assert result["failed"] == 1 # One FAILED + assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED + + def test_status_count_app_id_filtering(self): + """Test that status_count filters workflow runs by app_id for security.""" + # Arrange + app_id = str(uuid4()) + other_app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_run_id = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + # Mock message with workflow_run_id + mock_messages = [ + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id, + ), + ] + + calls_made = [] + + def mock_scalars(query): + calls_made.append(str(query)) + mock_result = MagicMock() + + if "messages" in str(query): + mock_result.all.return_value = mock_messages + elif "workflow_runs" in str(query): + # Return empty list because no workflow run matches the correct app_id + mock_result.all.return_value = [] # Workflow run filtered out by app_id + else: + mock_result.all.return_value = [] + + return mock_result + + # Act + with patch("models.model.db.session.scalars", side_effect=mock_scalars): + result = conversation.status_count + + # Assert - query should include app_id filter + workflow_query = calls_made[1] + assert "app_id" in workflow_query + + # Since workflow run has wrong app_id, it shouldn't be included in counts + assert result["success"] == 0 + assert result["failed"] == 0 + assert result["partial_success"] == 0 + + def test_status_count_handles_invalid_workflow_status(self): + """Test that status_count gracefully handles invalid workflow status values.""" + # Arrange + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_run_id = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + mock_messages = [ + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id, + ), + ] + + # Mock workflow run with invalid status + mock_workflow_runs = [ + MagicMock( + id=workflow_run_id, + status="invalid_status", # Invalid status that should raise ValueError + app_id=app_id, + ), + ] + + with patch("models.model.db.session.scalars") as mock_scalars: + # Mock the messages query + def mock_scalars_side_effect(query): + mock_result = MagicMock() + if "messages" in str(query): + mock_result.all.return_value = mock_messages + elif "workflow_runs" in str(query): + mock_result.all.return_value = mock_workflow_runs + else: + mock_result.all.return_value = [] + return mock_result + + mock_scalars.side_effect = mock_scalars_side_effect + + # Act - should not raise exception + result = conversation.status_count + + # Assert - should handle invalid status gracefully + assert result["success"] == 0 + assert result["failed"] == 0 + assert result["partial_success"] == 0 diff --git a/api/tests/unit_tests/services/document_indexing_task_proxy.py b/api/tests/unit_tests/services/document_indexing_task_proxy.py index 765c4b5e32..ff243b8dc3 100644 --- a/api/tests/unit_tests/services/document_indexing_task_proxy.py +++ b/api/tests/unit_tests/services/document_indexing_task_proxy.py @@ -117,7 +117,7 @@ import pytest from core.entities.document_task import DocumentTask from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy # ============================================================================ # Test Data Factory @@ -370,7 +370,7 @@ class TestDocumentIndexingTaskProxy: # Features Property Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_features_property(self, mock_feature_service): """ Test cached_property features. @@ -400,7 +400,7 @@ class TestDocumentIndexingTaskProxy: mock_feature_service.get_features.assert_called_once_with("tenant-123") - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_features_property_with_different_tenants(self, mock_feature_service): """ Test features property with different tenant IDs. @@ -438,7 +438,7 @@ class TestDocumentIndexingTaskProxy: # Direct Queue Routing Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue(self, mock_task): """ Test _send_to_direct_queue method. @@ -460,7 +460,7 @@ class TestDocumentIndexingTaskProxy: # Assert mock_task.delay.assert_called_once_with(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_direct_queue_with_priority_task(self, mock_task): """ Test _send_to_direct_queue with priority task function. @@ -481,7 +481,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_with_single_document(self, mock_task): """ Test _send_to_direct_queue with single document ID. @@ -502,7 +502,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_with_empty_documents(self, mock_task): """ Test _send_to_direct_queue with empty document_ids list. @@ -525,7 +525,7 @@ class TestDocumentIndexingTaskProxy: # Tenant Queue Routing Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): """ Test _send_to_tenant_queue when task key exists. @@ -564,7 +564,7 @@ class TestDocumentIndexingTaskProxy: mock_task.delay.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_without_task_key(self, mock_task): """ Test _send_to_tenant_queue when no task key exists. @@ -594,7 +594,7 @@ class TestDocumentIndexingTaskProxy: proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_tenant_queue_with_priority_task(self, mock_task): """ Test _send_to_tenant_queue with priority task function. @@ -621,7 +621,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_document_task_serialization(self, mock_task): """ Test DocumentTask serialization in _send_to_tenant_queue. @@ -659,7 +659,7 @@ class TestDocumentIndexingTaskProxy: # Queue Type Selection Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_default_tenant_queue(self, mock_task): """ Test _send_to_default_tenant_queue method. @@ -678,7 +678,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_priority_tenant_queue(self, mock_task): """ Test _send_to_priority_tenant_queue method. @@ -697,7 +697,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_tenant_queue.assert_called_once_with(mock_task) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_send_to_priority_direct_queue(self, mock_task): """ Test _send_to_priority_direct_queue method. @@ -720,7 +720,7 @@ class TestDocumentIndexingTaskProxy: # Dispatch Logic Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with SANDBOX plan. @@ -745,7 +745,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_team_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with TEAM plan. @@ -770,7 +770,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_enabled_professional_plan(self, mock_feature_service): """ Test _dispatch method when billing is enabled with PROFESSIONAL plan. @@ -795,7 +795,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_with_billing_disabled(self, mock_feature_service): """ Test _dispatch method when billing is disabled. @@ -818,7 +818,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_direct_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_edge_case_empty_plan(self, mock_feature_service): """ Test _dispatch method with empty plan string. @@ -842,7 +842,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_edge_case_none_plan(self, mock_feature_service): """ Test _dispatch method with None plan. @@ -870,7 +870,7 @@ class TestDocumentIndexingTaskProxy: # Delay Method Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method(self, mock_feature_service): """ Test delay method integration. @@ -895,7 +895,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method_with_team_plan(self, mock_feature_service): """ Test delay method with TEAM plan. @@ -920,7 +920,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_delay_method_with_billing_disabled(self, mock_feature_service): """ Test delay method with billing disabled. @@ -1021,7 +1021,7 @@ class TestDocumentIndexingTaskProxy: # Batch Operations Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_batch_operation_with_multiple_documents(self, mock_task): """ Test batch operation with multiple documents. @@ -1044,7 +1044,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=document_ids ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_batch_operation_with_large_batch(self, mock_task): """ Test batch operation with large batch of documents. @@ -1073,7 +1073,7 @@ class TestDocumentIndexingTaskProxy: # Error Handling Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue_task_delay_failure(self, mock_task): """ Test _send_to_direct_queue when task.delay() raises an exception. @@ -1090,7 +1090,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Task delay failed"): proxy._send_to_direct_queue(mock_task) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_push_tasks_failure(self, mock_task): """ Test _send_to_tenant_queue when push_tasks raises an exception. @@ -1111,7 +1111,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Push tasks failed"): proxy._send_to_tenant_queue(mock_task) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_set_waiting_time_failure(self, mock_task): """ Test _send_to_tenant_queue when set_task_waiting_time raises an exception. @@ -1132,7 +1132,7 @@ class TestDocumentIndexingTaskProxy: with pytest.raises(Exception, match="Set waiting time failed"): proxy._send_to_tenant_queue(mock_task) - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") def test_dispatch_feature_service_failure(self, mock_feature_service): """ Test _dispatch when FeatureService.get_features raises an exception. @@ -1153,8 +1153,8 @@ class TestDocumentIndexingTaskProxy: # Integration Tests # ======================================================================== - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_full_flow_sandbox_plan(self, mock_task, mock_feature_service): """ Test full flow for SANDBOX plan with tenant queue. @@ -1187,8 +1187,8 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_full_flow_team_plan(self, mock_task, mock_feature_service): """ Test full flow for TEAM plan with priority tenant queue. @@ -1221,8 +1221,8 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.FeatureService") - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.priority_document_indexing_task") def test_full_flow_billing_disabled(self, mock_task, mock_feature_service): """ Test full flow for billing disabled (self-hosted/enterprise). diff --git a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py new file mode 100644 index 0000000000..bd226f7536 --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py @@ -0,0 +1,177 @@ +import types +from unittest.mock import Mock, create_autospec + +import pytest +from redis.exceptions import LockNotOwnedError + +from models.account import Account +from models.dataset import Dataset, Document +from services.dataset_service import DocumentService, SegmentService + + +class FakeLock: + """Lock that always fails on enter with LockNotOwnedError.""" + + def __enter__(self): + raise LockNotOwnedError("simulated") + + def __exit__(self, exc_type, exc, tb): + # Normal contextmanager signature; return False so exceptions propagate + return False + + +@pytest.fixture +def fake_current_user(monkeypatch): + user = create_autospec(Account, instance=True) + user.id = "user-1" + user.current_tenant_id = "tenant-1" + monkeypatch.setattr("services.dataset_service.current_user", user) + return user + + +@pytest.fixture +def fake_features(monkeypatch): + """Features.billing.enabled == False to skip quota logic.""" + features = types.SimpleNamespace( + billing=types.SimpleNamespace(enabled=False, subscription=types.SimpleNamespace(plan="ENTERPRISE")), + documents_upload_quota=types.SimpleNamespace(limit=10_000, size=0), + ) + monkeypatch.setattr( + "services.dataset_service.FeatureService.get_features", + lambda tenant_id: features, + ) + return features + + +@pytest.fixture +def fake_lock(monkeypatch): + """Patch redis_client.lock to always raise LockNotOwnedError on enter.""" + + def _fake_lock(name, timeout=None, *args, **kwargs): + return FakeLock() + + # DatasetService imports redis_client directly from extensions.ext_redis + monkeypatch.setattr("services.dataset_service.redis_client.lock", _fake_lock) + + +# --------------------------------------------------------------------------- +# 1. Knowledge Pipeline document creation (save_document_with_dataset_id) +# --------------------------------------------------------------------------- + + +def test_save_document_with_dataset_id_ignores_lock_not_owned( + monkeypatch, + fake_current_user, + fake_features, + fake_lock, +): + # Arrange + dataset = create_autospec(Dataset, instance=True) + dataset.id = "ds-1" + dataset.tenant_id = fake_current_user.current_tenant_id + dataset.data_source_type = "upload_file" + dataset.indexing_technique = "high_quality" # so we skip re-initialization branch + + # Minimal knowledge_config stub that satisfies pre-lock code + info_list = types.SimpleNamespace(data_source_type="upload_file") + data_source = types.SimpleNamespace(info_list=info_list) + knowledge_config = types.SimpleNamespace( + doc_form="qa_model", + original_document_id=None, # go into "new document" branch + data_source=data_source, + indexing_technique="high_quality", + embedding_model=None, + embedding_model_provider=None, + retrieval_model=None, + process_rule=None, + duplicate=False, + doc_language="en", + ) + + account = fake_current_user + + # Avoid touching real doc_form logic + monkeypatch.setattr("services.dataset_service.DatasetService.check_doc_form", lambda *a, **k: None) + # Avoid real DB interactions + monkeypatch.setattr("services.dataset_service.db", Mock()) + + # Act: this would hit the redis lock, whose __enter__ raises LockNotOwnedError. + # Our implementation should catch it and still return (documents, batch). + documents, batch = DocumentService.save_document_with_dataset_id( + dataset=dataset, + knowledge_config=knowledge_config, + account=account, + ) + + # Assert + # We mainly care that: + # - No exception is raised + # - The function returns a sensible tuple + assert isinstance(documents, list) + assert isinstance(batch, str) + + +# --------------------------------------------------------------------------- +# 2. Single-segment creation (add_segment) +# --------------------------------------------------------------------------- + + +def test_add_segment_ignores_lock_not_owned( + monkeypatch, + fake_current_user, + fake_lock, +): + # Arrange + dataset = create_autospec(Dataset, instance=True) + dataset.id = "ds-1" + dataset.tenant_id = fake_current_user.current_tenant_id + dataset.indexing_technique = "economy" # skip embedding/token calculation branch + + document = create_autospec(Document, instance=True) + document.id = "doc-1" + document.dataset_id = dataset.id + document.word_count = 0 + document.doc_form = "qa_model" + + # Minimal args required by add_segment + args = { + "content": "question text", + "answer": "answer text", + "keywords": ["k1", "k2"], + } + + # Avoid real DB operations + db_mock = Mock() + db_mock.session = Mock() + monkeypatch.setattr("services.dataset_service.db", db_mock) + monkeypatch.setattr("services.dataset_service.VectorService", Mock()) + + # Act + result = SegmentService.create_segment(args=args, document=document, dataset=dataset) + + # Assert + # Under LockNotOwnedError except, add_segment should swallow the error and return None. + assert result is None + + +# --------------------------------------------------------------------------- +# 3. Multi-segment creation (multi_create_segment) +# --------------------------------------------------------------------------- + + +def test_multi_create_segment_ignores_lock_not_owned( + monkeypatch, + fake_current_user, + fake_lock, +): + # Arrange + dataset = create_autospec(Dataset, instance=True) + dataset.id = "ds-1" + dataset.tenant_id = fake_current_user.current_tenant_id + dataset.indexing_technique = "economy" # again, skip high_quality path + + document = create_autospec(Document, instance=True) + document.id = "doc-1" + document.dataset_id = dataset.id + document.word_count = 0 + document.doc_form = "qa_model" diff --git a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py index d9183be9fb..98c30c3722 100644 --- a/api/tests/unit_tests/services/test_document_indexing_task_proxy.py +++ b/api/tests/unit_tests/services/test_document_indexing_task_proxy.py @@ -3,7 +3,7 @@ from unittest.mock import Mock, patch from core.entities.document_task import DocumentTask from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy class DocumentIndexingTaskProxyTestDataFactory: @@ -59,7 +59,7 @@ class TestDocumentIndexingTaskProxy: assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id assert proxy._tenant_isolated_task_queue._unique_key == "document_indexing" - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_features_property(self, mock_feature_service): """Test cached_property features.""" # Arrange @@ -77,7 +77,7 @@ class TestDocumentIndexingTaskProxy: assert features1 is features2 # Should be the same instance due to caching mock_feature_service.get_features.assert_called_once_with("tenant-123") - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_direct_queue(self, mock_task): """Test _send_to_direct_queue method.""" # Arrange @@ -92,7 +92,7 @@ class TestDocumentIndexingTaskProxy: tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] ) - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): """Test _send_to_tenant_queue when task key exists.""" # Arrange @@ -115,7 +115,7 @@ class TestDocumentIndexingTaskProxy: assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] mock_task.delay.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") + @patch("services.document_indexing_proxy.document_indexing_task_proxy.normal_document_indexing_task") def test_send_to_tenant_queue_without_task_key(self, mock_task): """Test _send_to_tenant_queue when no task key exists.""" # Arrange @@ -135,8 +135,7 @@ class TestDocumentIndexingTaskProxy: ) proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() - @patch("services.document_indexing_task_proxy.normal_document_indexing_task") - def test_send_to_default_tenant_queue(self, mock_task): + def test_send_to_default_tenant_queue(self): """Test _send_to_default_tenant_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -146,10 +145,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_default_tenant_queue() # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_tenant_queue(self, mock_task): + def test_send_to_priority_tenant_queue(self): """Test _send_to_priority_tenant_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -159,10 +157,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_priority_tenant_queue() # Assert - proxy._send_to_tenant_queue.assert_called_once_with(mock_task) + proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) - @patch("services.document_indexing_task_proxy.priority_document_indexing_task") - def test_send_to_priority_direct_queue(self, mock_task): + def test_send_to_priority_direct_queue(self): """Test _send_to_priority_direct_queue method.""" # Arrange proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy() @@ -172,9 +169,9 @@ class TestDocumentIndexingTaskProxy: proxy._send_to_priority_direct_queue() # Assert - proxy._send_to_direct_queue.assert_called_once_with(mock_task) + proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): """Test _dispatch method when billing is enabled with sandbox plan.""" # Arrange @@ -191,7 +188,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_default_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service): """Test _dispatch method when billing is enabled with non-sandbox plan.""" # Arrange @@ -208,7 +205,7 @@ class TestDocumentIndexingTaskProxy: # If billing enabled with non sandbox plan, should send to priority tenant queue proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_with_billing_disabled(self, mock_feature_service): """Test _dispatch method when billing is disabled.""" # Arrange @@ -223,7 +220,7 @@ class TestDocumentIndexingTaskProxy: # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue proxy._send_to_priority_direct_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_delay_method(self, mock_feature_service): """Test delay method integration.""" # Arrange @@ -256,7 +253,7 @@ class TestDocumentIndexingTaskProxy: assert task.dataset_id == dataset_id assert task.document_ids == document_ids - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_edge_case_empty_plan(self, mock_feature_service): """Test _dispatch method with empty plan string.""" # Arrange @@ -271,7 +268,7 @@ class TestDocumentIndexingTaskProxy: # Assert proxy._send_to_priority_tenant_queue.assert_called_once() - @patch("services.document_indexing_task_proxy.FeatureService") + @patch("services.document_indexing_proxy.base.FeatureService") def test_dispatch_edge_case_none_plan(self, mock_feature_service): """Test _dispatch method with None plan.""" # Arrange diff --git a/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py b/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py new file mode 100644 index 0000000000..68bafe3d5e --- /dev/null +++ b/api/tests/unit_tests/services/test_duplicate_document_indexing_task_proxy.py @@ -0,0 +1,363 @@ +from unittest.mock import Mock, patch + +from core.entities.document_task import DocumentTask +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from services.document_indexing_proxy.duplicate_document_indexing_task_proxy import ( + DuplicateDocumentIndexingTaskProxy, +) + + +class DuplicateDocumentIndexingTaskProxyTestDataFactory: + """Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests.""" + + @staticmethod + def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock: + """Create mock features with billing configuration.""" + features = Mock() + features.billing = Mock() + features.billing.enabled = billing_enabled + features.billing.subscription = Mock() + features.billing.subscription.plan = plan + return features + + @staticmethod + def create_mock_tenant_queue(has_task_key: bool = False) -> Mock: + """Create mock TenantIsolatedTaskQueue.""" + queue = Mock(spec=TenantIsolatedTaskQueue) + queue.get_task_key.return_value = "task_key" if has_task_key else None + queue.push_tasks = Mock() + queue.set_task_waiting_time = Mock() + return queue + + @staticmethod + def create_duplicate_document_task_proxy( + tenant_id: str = "tenant-123", dataset_id: str = "dataset-456", document_ids: list[str] | None = None + ) -> DuplicateDocumentIndexingTaskProxy: + """Create DuplicateDocumentIndexingTaskProxy instance for testing.""" + if document_ids is None: + document_ids = ["doc-1", "doc-2", "doc-3"] + return DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + +class TestDuplicateDocumentIndexingTaskProxy: + """Test cases for DuplicateDocumentIndexingTaskProxy class.""" + + def test_initialization(self): + """Test DuplicateDocumentIndexingTaskProxy initialization.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1", "doc-2", "doc-3"] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + assert isinstance(proxy._tenant_isolated_task_queue, TenantIsolatedTaskQueue) + assert proxy._tenant_isolated_task_queue._tenant_id == tenant_id + assert proxy._tenant_isolated_task_queue._unique_key == "duplicate_document_indexing" + + def test_queue_name(self): + """Test QUEUE_NAME class variable.""" + # Arrange & Act + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Assert + assert proxy.QUEUE_NAME == "duplicate_document_indexing" + + def test_task_functions(self): + """Test NORMAL_TASK_FUNC and PRIORITY_TASK_FUNC class variables.""" + # Arrange & Act + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Assert + assert proxy.NORMAL_TASK_FUNC.__name__ == "normal_duplicate_document_indexing_task" + assert proxy.PRIORITY_TASK_FUNC.__name__ == "priority_duplicate_document_indexing_task" + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_features_property(self, mock_feature_service): + """Test cached_property features.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features() + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + + # Act + features1 = proxy.features + features2 = proxy.features # Second call should use cached property + + # Assert + assert features1 == mock_features + assert features2 == mock_features + assert features1 is features2 # Should be the same instance due to caching + mock_feature_service.get_features.assert_called_once_with("tenant-123") + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_direct_queue(self, mock_task): + """Test _send_to_direct_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + mock_task.delay = Mock() + + # Act + proxy._send_to_direct_queue(mock_task) + + # Assert + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_tenant_queue_with_existing_task_key(self, mock_task): + """Test _send_to_tenant_queue when task key exists.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=True + ) + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.push_tasks.assert_called_once() + pushed_tasks = proxy._tenant_isolated_task_queue.push_tasks.call_args[0][0] + assert len(pushed_tasks) == 1 + assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask) + assert pushed_tasks[0]["tenant_id"] == "tenant-123" + assert pushed_tasks[0]["dataset_id"] == "dataset-456" + assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"] + mock_task.delay.assert_not_called() + + @patch( + "services.document_indexing_proxy.duplicate_document_indexing_task_proxy.normal_duplicate_document_indexing_task" + ) + def test_send_to_tenant_queue_without_task_key(self, mock_task): + """Test _send_to_tenant_queue when no task key exists.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._tenant_isolated_task_queue = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue( + has_task_key=False + ) + mock_task.delay = Mock() + + # Act + proxy._send_to_tenant_queue(mock_task) + + # Assert + proxy._tenant_isolated_task_queue.set_task_waiting_time.assert_called_once() + mock_task.delay.assert_called_once_with( + tenant_id="tenant-123", dataset_id="dataset-456", document_ids=["doc-1", "doc-2", "doc-3"] + ) + proxy._tenant_isolated_task_queue.push_tasks.assert_not_called() + + def test_send_to_default_tenant_queue(self): + """Test _send_to_default_tenant_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_default_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(proxy.NORMAL_TASK_FUNC) + + def test_send_to_priority_tenant_queue(self): + """Test _send_to_priority_tenant_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_tenant_queue = Mock() + + # Act + proxy._send_to_priority_tenant_queue() + + # Assert + proxy._send_to_tenant_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) + + def test_send_to_priority_direct_queue(self): + """Test _send_to_priority_direct_queue method.""" + # Arrange + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_direct_queue = Mock() + + # Act + proxy._send_to_priority_direct_queue() + + # Assert + proxy._send_to_direct_queue.assert_called_once_with(proxy.PRIORITY_TASK_FUNC) + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with sandbox plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_default_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with non-sandbox plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.TEAM + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + # If billing enabled with non sandbox plan, should send to priority tenant queue + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_billing_disabled(self, mock_feature_service): + """Test _dispatch method when billing is disabled.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features(billing_enabled=False) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_direct_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + # If billing disabled, for example: self-hosted or enterprise, should send to priority direct queue + proxy._send_to_priority_direct_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_delay_method(self, mock_feature_service): + """Test delay method integration.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.SANDBOX + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_default_tenant_queue = Mock() + + # Act + proxy.delay() + + # Assert + # If billing enabled with sandbox plan, should send to default tenant queue + proxy._send_to_default_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_edge_case_empty_plan(self, mock_feature_service): + """Test _dispatch method with empty plan string.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan="" + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_edge_case_none_plan(self, mock_feature_service): + """Test _dispatch method with None plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=None + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() + + def test_initialization_with_empty_document_ids(self): + """Test initialization with empty document_ids list.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = [] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + + def test_initialization_with_single_document_id(self): + """Test initialization with single document_id.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = ["doc-1"] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + + def test_initialization_with_large_batch(self): + """Test initialization with large batch of document IDs.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-456" + document_ids = [f"doc-{i}" for i in range(100)] + + # Act + proxy = DuplicateDocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) + + # Assert + assert proxy._tenant_id == tenant_id + assert proxy._dataset_id == dataset_id + assert proxy._document_ids == document_ids + assert len(proxy._document_ids) == 100 + + @patch("services.document_indexing_proxy.base.FeatureService") + def test_dispatch_with_professional_plan(self, mock_feature_service): + """Test _dispatch method when billing is enabled with professional plan.""" + # Arrange + mock_features = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_mock_features( + billing_enabled=True, plan=CloudPlan.PROFESSIONAL + ) + mock_feature_service.get_features.return_value = mock_features + proxy = DuplicateDocumentIndexingTaskProxyTestDataFactory.create_duplicate_document_task_proxy() + proxy._send_to_priority_tenant_queue = Mock() + + # Act + proxy._dispatch() + + # Assert + proxy._send_to_priority_tenant_queue.assert_called_once() diff --git a/api/tests/unit_tests/services/test_external_dataset_service.py b/api/tests/unit_tests/services/test_external_dataset_service.py index c12ea2f7cb..e2d62583f8 100644 --- a/api/tests/unit_tests/services/test_external_dataset_service.py +++ b/api/tests/unit_tests/services/test_external_dataset_service.py @@ -6,6 +6,7 @@ Target: 1500+ lines of comprehensive test coverage. """ import json +import re from datetime import datetime from unittest.mock import MagicMock, Mock, patch @@ -1791,8 +1792,8 @@ class TestExternalDatasetServiceFetchRetrieval: @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @patch("services.external_knowledge_service.db") - def test_fetch_external_knowledge_retrieval_non_200_status(self, mock_db, mock_process, factory): - """Test retrieval returns empty list on non-200 status.""" + def test_fetch_external_knowledge_retrieval_non_200_status_raises_exception(self, mock_db, mock_process, factory): + """Test that non-200 status code raises Exception with response text.""" # Arrange binding = factory.create_external_knowledge_binding_mock() api = factory.create_external_knowledge_api_mock() @@ -1817,12 +1818,103 @@ class TestExternalDatasetServiceFetchRetrieval: mock_response = MagicMock() mock_response.status_code = 500 + mock_response.text = "Internal Server Error: Database connection failed" mock_process.return_value = mock_response - # Act - result = ExternalDatasetService.fetch_external_knowledge_retrieval( - "tenant-123", "dataset-123", "query", {"top_k": 5} - ) + # Act & Assert + with pytest.raises(Exception, match="Internal Server Error: Database connection failed"): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) - # Assert - assert result == [] + @pytest.mark.parametrize( + ("status_code", "error_message"), + [ + (400, "Bad Request: Invalid query parameters"), + (401, "Unauthorized: Invalid API key"), + (403, "Forbidden: Access denied to resource"), + (404, "Not Found: Knowledge base not found"), + (429, "Too Many Requests: Rate limit exceeded"), + (500, "Internal Server Error: Database connection failed"), + (502, "Bad Gateway: External service unavailable"), + (503, "Service Unavailable: Maintenance mode"), + ], + ) + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + @patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_various_error_status_codes( + self, mock_db, mock_process, factory, status_code, error_message + ): + """Test that various error status codes raise exceptions with response text.""" + # Arrange + tenant_id = "tenant-123" + dataset_id = "dataset-123" + + binding = factory.create_external_knowledge_binding_mock( + dataset_id=dataset_id, external_knowledge_api_id="api-123" + ) + api = factory.create_external_knowledge_api_mock(api_id="api-123") + + mock_binding_query = MagicMock() + mock_api_query = MagicMock() + + def query_side_effect(model): + if model == ExternalKnowledgeBindings: + return mock_binding_query + elif model == ExternalKnowledgeApis: + return mock_api_query + return MagicMock() + + mock_db.session.query.side_effect = query_side_effect + + mock_binding_query.filter_by.return_value = mock_binding_query + mock_binding_query.first.return_value = binding + + mock_api_query.filter_by.return_value = mock_api_query + mock_api_query.first.return_value = api + + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.text = error_message + mock_process.return_value = mock_response + + # Act & Assert + with pytest.raises(ValueError, match=re.escape(error_message)): + ExternalDatasetService.fetch_external_knowledge_retrieval(tenant_id, dataset_id, "query", {"top_k": 5}) + + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + @patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_empty_response_text(self, mock_db, mock_process, factory): + """Test exception with empty response text.""" + # Arrange + binding = factory.create_external_knowledge_binding_mock() + api = factory.create_external_knowledge_api_mock() + + mock_binding_query = MagicMock() + mock_api_query = MagicMock() + + def query_side_effect(model): + if model == ExternalKnowledgeBindings: + return mock_binding_query + elif model == ExternalKnowledgeApis: + return mock_api_query + return MagicMock() + + mock_db.session.query.side_effect = query_side_effect + + mock_binding_query.filter_by.return_value = mock_binding_query + mock_binding_query.first.return_value = binding + + mock_api_query.filter_by.return_value = mock_api_query + mock_api_query.first.return_value = api + + mock_response = MagicMock() + mock_response.status_code = 503 + mock_response.text = "" + mock_process.return_value = mock_response + + # Act & Assert + with pytest.raises(Exception, match=""): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index bbfa9da15e..fc3a2fc416 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -2,8 +2,6 @@ from pathlib import Path from unittest.mock import Mock, create_autospec, patch import pytest -from flask_restx import reqparse -from werkzeug.exceptions import BadRequest from models.account import Account from services.entities.knowledge_entities.knowledge_entities import MetadataArgs @@ -77,60 +75,39 @@ class TestMetadataBugCompleteValidation: assert type_column.nullable is False, "type column should be nullable=False" assert name_column.nullable is False, "name column should be nullable=False" - def test_4_fixed_api_layer_rejects_null(self, app): - """Test Layer 4: Fixed API configuration properly rejects null values.""" - # Test Console API create endpoint (fixed) - parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - ) + def test_4_fixed_api_layer_rejects_null(self): + """Test Layer 4: Fixed API configuration properly rejects null values using Pydantic.""" + with pytest.raises((ValueError, TypeError)): + MetadataArgs.model_validate({"type": None, "name": None}) - with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): - with pytest.raises(BadRequest): - parser.parse_args() + with pytest.raises((ValueError, TypeError)): + MetadataArgs.model_validate({"type": "string", "name": None}) - # Test with just name being null - with app.test_request_context(json={"type": "string", "name": None}, content_type="application/json"): - with pytest.raises(BadRequest): - parser.parse_args() + with pytest.raises((ValueError, TypeError)): + MetadataArgs.model_validate({"type": None, "name": "test"}) - # Test with just type being null - with app.test_request_context(json={"type": None, "name": "test"}, content_type="application/json"): - with pytest.raises(BadRequest): - parser.parse_args() - - def test_5_fixed_api_accepts_valid_values(self, app): + def test_5_fixed_api_accepts_valid_values(self): """Test that fixed API still accepts valid non-null values.""" - parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - ) + args = MetadataArgs.model_validate({"type": "string", "name": "valid_name"}) + assert args.type == "string" + assert args.name == "valid_name" - with app.test_request_context(json={"type": "string", "name": "valid_name"}, content_type="application/json"): - args = parser.parse_args() - assert args["type"] == "string" - assert args["name"] == "valid_name" + def test_6_simulated_buggy_behavior(self): + """Test simulating the original buggy behavior by bypassing Pydantic validation.""" + mock_metadata_args = Mock() + mock_metadata_args.name = None + mock_metadata_args.type = None - def test_6_simulated_buggy_behavior(self, app): - """Test simulating the original buggy behavior with nullable=True.""" - # Simulate the old buggy configuration - buggy_parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=True, location="json") - .add_argument("name", type=str, required=True, nullable=True, location="json") - ) + mock_user = create_autospec(Account, instance=True) + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" - with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): - # This would pass in the buggy version - args = buggy_parser.parse_args() - assert args["type"] is None - assert args["name"] is None - - # But would crash when trying to create MetadataArgs - with pytest.raises((ValueError, TypeError)): - MetadataArgs.model_validate(args) + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.create_metadata("dataset-123", mock_metadata_args) def test_7_end_to_end_validation_layers(self): """Test all validation layers work together correctly.""" diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index c8a1a70422..f43f394489 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -1,7 +1,6 @@ from unittest.mock import Mock, create_autospec, patch import pytest -from flask_restx import reqparse from models.account import Account from services.entities.knowledge_entities.knowledge_entities import MetadataArgs @@ -51,76 +50,16 @@ class TestMetadataNullableBug: with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.update_metadata_name("dataset-123", "metadata-456", None) - def test_api_parser_accepts_null_values(self, app): - """Test that API parser configuration incorrectly accepts null values.""" - # Simulate the current API parser configuration - parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=True, location="json") - .add_argument("name", type=str, required=True, nullable=True, location="json") - ) + def test_api_layer_now_uses_pydantic_validation(self): + """Verify that API layer relies on Pydantic validation instead of reqparse.""" + invalid_payload = {"type": None, "name": None} + with pytest.raises((ValueError, TypeError)): + MetadataArgs.model_validate(invalid_payload) - # Simulate request data with null values - with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): - # This should parse successfully due to nullable=True - args = parser.parse_args() - - # Verify that null values are accepted - assert args["type"] is None - assert args["name"] is None - - # This demonstrates the bug: API accepts None but business logic will crash - - def test_integration_bug_scenario(self, app): - """Test the complete bug scenario from API to service layer.""" - # Step 1: API parser accepts null values (current buggy behavior) - parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=True, location="json") - .add_argument("name", type=str, required=True, nullable=True, location="json") - ) - - with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): - args = parser.parse_args() - - # Step 2: Try to create MetadataArgs with None values - # This should fail at Pydantic validation level - with pytest.raises((ValueError, TypeError)): - metadata_args = MetadataArgs.model_validate(args) - - # Step 3: If we bypass Pydantic (simulating the bug scenario) - # Move this outside the request context to avoid Flask-Login issues - mock_metadata_args = Mock() - mock_metadata_args.name = None # From args["name"] - mock_metadata_args.type = None # From args["type"] - - mock_user = create_autospec(Account, instance=True) - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" - - with patch( - "services.metadata_service.current_account_with_tenant", - return_value=(mock_user, mock_user.current_tenant_id), - ): - # Step 4: Service layer crashes on len(None) - with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.create_metadata("dataset-123", mock_metadata_args) - - def test_correct_nullable_false_configuration_works(self, app): - """Test that the correct nullable=False configuration works as expected.""" - # This tests the FIXED configuration - parser = ( - reqparse.RequestParser() - .add_argument("type", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - ) - - with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): - # This should fail with BadRequest due to nullable=False - from werkzeug.exceptions import BadRequest - - with pytest.raises(BadRequest): - parser.parse_args() + valid_payload = {"type": "string", "name": "valid"} + args = MetadataArgs.model_validate(valid_payload) + assert args.type == "string" + assert args.name == "valid" if __name__ == "__main__": diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py new file mode 100644 index 0000000000..9a107da1c7 --- /dev/null +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -0,0 +1,88 @@ +import types + +import pytest + +from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ConfigurateMethod +from models.provider import ProviderType +from services.model_provider_service import ModelProviderService + + +class _FakeConfigurations: + def __init__(self, provider_configuration: types.SimpleNamespace) -> None: + self._provider_configuration = provider_configuration + + def values(self) -> list[types.SimpleNamespace]: + return [self._provider_configuration] + + +@pytest.fixture +def service_with_fake_configurations(): + # Build a fake provider schema with minimal fields used by ProviderResponse + fake_provider = types.SimpleNamespace( + provider="langgenius/openai_api_compatible/openai_api_compatible", + label=I18nObject(en_US="OpenAI API Compatible", zh_Hans="OpenAI API Compatible"), + description=None, + icon_small=None, + icon_small_dark=None, + icon_large=None, + background=None, + help=None, + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL], + provider_credential_schema=None, + model_credential_schema=None, + ) + + # Include decrypted credentials to simulate the leak source + custom_model = CustomModelConfiguration( + model="gpt-4o-mini", + model_type=ModelType.LLM, + credentials={"api_key": "sk-plain-text", "endpoint": "https://example.com"}, + current_credential_id="cred-1", + current_credential_name="API KEY 1", + available_model_credentials=[], + unadded_to_model_list=False, + ) + + fake_custom_provider = types.SimpleNamespace( + current_credential_id="cred-1", + current_credential_name="API KEY 1", + available_credentials=[CredentialConfiguration(credential_id="cred-1", credential_name="API KEY 1")], + ) + + fake_custom_configuration = types.SimpleNamespace( + provider=fake_custom_provider, models=[custom_model], can_added_models=[] + ) + + fake_system_configuration = types.SimpleNamespace(enabled=False, current_quota_type=None, quota_configurations=[]) + + fake_provider_configuration = types.SimpleNamespace( + provider=fake_provider, + preferred_provider_type=ProviderType.CUSTOM, + custom_configuration=fake_custom_configuration, + system_configuration=fake_system_configuration, + is_custom_configuration_available=lambda: True, + ) + + class _FakeProviderManager: + def get_configurations(self, tenant_id: str) -> _FakeConfigurations: + return _FakeConfigurations(fake_provider_configuration) + + svc = ModelProviderService() + svc.provider_manager = _FakeProviderManager() + return svc + + +def test_get_provider_list_strips_credentials(service_with_fake_configurations: ModelProviderService): + providers = service_with_fake_configurations.get_provider_list(tenant_id="tenant-1", model_type=None) + + assert len(providers) == 1 + custom_models = providers[0].custom_configuration.custom_models + + assert custom_models is not None + assert len(custom_models) == 1 + # The sanitizer should drop credentials in list response + assert custom_models[0].credentials is None diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index b3b29fbe45..9d7599b8fe 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -19,7 +19,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.dataset import Dataset, Document -from services.document_indexing_task_proxy import DocumentIndexingTaskProxy +from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy from tasks.document_indexing_task import ( _document_indexing, _document_indexing_with_tenant_queue, @@ -138,7 +138,9 @@ class TestTaskEnqueuing: with patch.object(DocumentIndexingTaskProxy, "features") as mock_features: mock_features.billing.enabled = False - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -163,7 +165,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.SANDBOX - with patch("services.document_indexing_task_proxy.normal_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "NORMAL_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -187,7 +191,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -211,7 +217,9 @@ class TestTaskEnqueuing: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids) # Act @@ -1493,7 +1501,9 @@ class TestEdgeCases: mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL - with patch("services.document_indexing_task_proxy.priority_document_indexing_task") as mock_task: + # Mock the class variable directly + mock_task = Mock() + with patch.object(DocumentIndexingTaskProxy, "PRIORITY_TASK_FUNC", mock_task): # Act - Enqueue multiple tasks rapidly for doc_ids in document_ids_list: proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, doc_ids) @@ -1898,7 +1908,7 @@ class TestRobustness: - Error is propagated appropriately """ # Arrange - with patch("services.document_indexing_task_proxy.FeatureService.get_features") as mock_get_features: + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_get_features: # Simulate FeatureService failure mock_get_features.side_effect = Exception("Feature service unavailable") diff --git a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py new file mode 100644 index 0000000000..0be6ea045e --- /dev/null +++ b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py @@ -0,0 +1,567 @@ +""" +Unit tests for duplicate document indexing tasks. + +This module tests the duplicate document indexing task functionality including: +- Task enqueuing to different queues (normal, priority, tenant-isolated) +- Batch processing of multiple duplicate documents +- Progress tracking through task lifecycle +- Error handling and retry mechanisms +- Cleanup of old document data before re-indexing +""" + +import uuid +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.pipeline.queue import TenantIsolatedTaskQueue +from enums.cloud_plan import CloudPlan +from models.dataset import Dataset, Document, DocumentSegment +from tasks.duplicate_document_indexing_task import ( + _duplicate_document_indexing_task, + _duplicate_document_indexing_task_with_tenant_queue, + duplicate_document_indexing_task, + normal_duplicate_document_indexing_task, + priority_duplicate_document_indexing_task, +) + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def tenant_id(): + """Generate a unique tenant ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def dataset_id(): + """Generate a unique dataset ID for testing.""" + return str(uuid.uuid4()) + + +@pytest.fixture +def document_ids(): + """Generate a list of document IDs for testing.""" + return [str(uuid.uuid4()) for _ in range(3)] + + +@pytest.fixture +def mock_dataset(dataset_id, tenant_id): + """Create a mock Dataset object.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding-ada-002" + return dataset + + +@pytest.fixture +def mock_documents(document_ids, dataset_id): + """Create mock Document objects.""" + documents = [] + for doc_id in document_ids: + doc = Mock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + doc.processing_started_at = None + doc.doc_form = "text_model" + documents.append(doc) + return documents + + +@pytest.fixture +def mock_document_segments(document_ids): + """Create mock DocumentSegment objects.""" + segments = [] + for doc_id in document_ids: + for i in range(3): + segment = Mock(spec=DocumentSegment) + segment.id = str(uuid.uuid4()) + segment.document_id = doc_id + segment.index_node_id = f"node-{doc_id}-{i}" + segments.append(segment) + return segments + + +@pytest.fixture +def mock_db_session(): + """Mock database session.""" + with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session: + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_session.scalars.return_value = MagicMock() + yield mock_session + + +@pytest.fixture +def mock_indexing_runner(): + """Mock IndexingRunner.""" + with patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_runner_class: + mock_runner = MagicMock(spec=IndexingRunner) + mock_runner_class.return_value = mock_runner + yield mock_runner + + +@pytest.fixture +def mock_feature_service(): + """Mock FeatureService.""" + with patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_service: + mock_features = Mock() + mock_features.billing = Mock() + mock_features.billing.enabled = False + mock_features.vector_space = Mock() + mock_features.vector_space.size = 0 + mock_features.vector_space.limit = 1000 + mock_service.get_features.return_value = mock_features + yield mock_service + + +@pytest.fixture +def mock_index_processor_factory(): + """Mock IndexProcessorFactory.""" + with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_factory: + mock_processor = MagicMock() + mock_processor.clean = Mock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + yield mock_factory + + +@pytest.fixture +def mock_tenant_isolated_queue(): + """Mock TenantIsolatedTaskQueue.""" + with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") as mock_queue_class: + mock_queue = MagicMock(spec=TenantIsolatedTaskQueue) + mock_queue.pull_tasks.return_value = [] + mock_queue.delete_task_key = Mock() + mock_queue.set_task_waiting_time = Mock() + mock_queue_class.return_value = mock_queue + yield mock_queue + + +# ============================================================================ +# Tests for deprecated duplicate_document_indexing_task +# ============================================================================ + + +class TestDuplicateDocumentIndexingTask: + """Tests for the deprecated duplicate_document_indexing_task function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_duplicate_document_indexing_task_calls_core_function(self, mock_core_func, dataset_id, document_ids): + """Test that duplicate_document_indexing_task calls the core _duplicate_document_indexing_task function.""" + # Act + duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + mock_core_func.assert_called_once_with(dataset_id, document_ids) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_duplicate_document_indexing_task_with_empty_document_ids(self, mock_core_func, dataset_id): + """Test duplicate_document_indexing_task with empty document_ids list.""" + # Arrange + document_ids = [] + + # Act + duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + mock_core_func.assert_called_once_with(dataset_id, document_ids) + + +# ============================================================================ +# Tests for _duplicate_document_indexing_task core function +# ============================================================================ + + +class TestDuplicateDocumentIndexingTaskCore: + """Tests for the _duplicate_document_indexing_task core function.""" + + def test_successful_duplicate_document_indexing( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + mock_document_segments, + dataset_id, + document_ids, + ): + """Test successful duplicate document indexing flow.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Verify IndexingRunner was called + mock_indexing_runner.run.assert_called_once() + + # Verify all documents were set to parsing status + for doc in mock_documents: + assert doc.indexing_status == "parsing" + assert doc.processing_started_at is not None + + # Verify session operations + assert mock_db_session.commit.called + assert mock_db_session.close.called + + def test_duplicate_document_indexing_dataset_not_found(self, mock_db_session, dataset_id, document_ids): + """Test duplicate document indexing when dataset is not found.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = None + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should close the session at least once + assert mock_db_session.close.called + + def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( + self, + mock_db_session, + mock_feature_service, + mock_dataset, + dataset_id, + document_ids, + ): + """Test duplicate document indexing with billing enabled and sandbox plan.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_features = mock_feature_service.get_features.return_value + mock_features.billing.enabled = True + mock_features.billing.subscription.plan = CloudPlan.SANDBOX + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # For sandbox plan with multiple documents, should fail + mock_db_session.commit.assert_called() + + def test_duplicate_document_indexing_with_billing_limit_exceeded( + self, + mock_db_session, + mock_feature_service, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when billing limit is exceeded.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] # No segments to clean + mock_features = mock_feature_service.get_features.return_value + mock_features.billing.enabled = True + mock_features.billing.subscription.plan = CloudPlan.TEAM + mock_features.vector_space.size = 990 + mock_features.vector_space.limit = 1000 + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should commit the session + assert mock_db_session.commit.called + # Should close the session + assert mock_db_session.close.called + + def test_duplicate_document_indexing_runner_error( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when IndexingRunner raises an error.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] + mock_indexing_runner.run.side_effect = Exception("Indexing error") + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should close the session even after error + mock_db_session.close.assert_called_once() + + def test_duplicate_document_indexing_document_is_paused( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + dataset_id, + document_ids, + ): + """Test duplicate document indexing when document is paused.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = [] + mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Should handle DocumentIsPausedError gracefully + mock_db_session.close.assert_called_once() + + def test_duplicate_document_indexing_cleans_old_segments( + self, + mock_db_session, + mock_indexing_runner, + mock_feature_service, + mock_index_processor_factory, + mock_dataset, + mock_documents, + mock_document_segments, + dataset_id, + document_ids, + ): + """Test that duplicate document indexing cleans old segments.""" + # Arrange + mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents + mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value + + # Act + _duplicate_document_indexing_task(dataset_id, document_ids) + + # Assert + # Verify clean was called for each document + assert mock_processor.clean.call_count == len(mock_documents) + + # Verify segments were deleted + for segment in mock_document_segments: + mock_db_session.delete.assert_any_call(segment) + + +# ============================================================================ +# Tests for tenant queue wrapper function +# ============================================================================ + + +class TestDuplicateDocumentIndexingTaskWithTenantQueue: + """Tests for _duplicate_document_indexing_task_with_tenant_queue function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_calls_core_function( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper calls the core function.""" + # Arrange + mock_task_func = Mock() + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_core_func.assert_called_once_with(dataset_id, document_ids) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_deletes_key_when_no_tasks( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper deletes task key when no more tasks.""" + # Arrange + mock_task_func = Mock() + mock_tenant_isolated_queue.pull_tasks.return_value = [] + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_tenant_isolated_queue.delete_task_key.assert_called_once() + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_processes_next_tasks( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper processes next tasks from queue.""" + # Arrange + mock_task_func = Mock() + next_task = { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": document_ids, + } + mock_tenant_isolated_queue.pull_tasks.return_value = [next_task] + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + mock_tenant_isolated_queue.set_task_waiting_time.assert_called_once() + mock_task_func.delay.assert_called_once_with( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_ids=document_ids, + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + def test_tenant_queue_wrapper_handles_core_function_error( + self, + mock_core_func, + mock_tenant_isolated_queue, + tenant_id, + dataset_id, + document_ids, + ): + """Test that tenant queue wrapper handles errors from core function.""" + # Arrange + mock_task_func = Mock() + mock_core_func.side_effect = Exception("Core function error") + + # Act + _duplicate_document_indexing_task_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task_func) + + # Assert + # Should still check for next tasks even after error + mock_tenant_isolated_queue.pull_tasks.assert_called_once() + + +# ============================================================================ +# Tests for normal_duplicate_document_indexing_task +# ============================================================================ + + +class TestNormalDuplicateDocumentIndexingTask: + """Tests for normal_duplicate_document_indexing_task function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_normal_task_calls_tenant_queue_wrapper( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + document_ids, + ): + """Test that normal task calls tenant queue wrapper.""" + # Act + normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_normal_task_with_empty_document_ids( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test normal task with empty document_ids list.""" + # Arrange + document_ids = [] + + # Act + normal_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task + ) + + +# ============================================================================ +# Tests for priority_duplicate_document_indexing_task +# ============================================================================ + + +class TestPriorityDuplicateDocumentIndexingTask: + """Tests for priority_duplicate_document_indexing_task function.""" + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_calls_tenant_queue_wrapper( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + document_ids, + ): + """Test that priority task calls tenant queue wrapper.""" + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_with_single_document( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test priority task with single document.""" + # Arrange + document_ids = ["doc-1"] + + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) + + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + def test_priority_task_with_large_batch( + self, + mock_wrapper_func, + tenant_id, + dataset_id, + ): + """Test priority task with large batch of documents.""" + # Arrange + document_ids = [f"doc-{i}" for i in range(100)] + + # Act + priority_duplicate_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_wrapper_func.assert_called_once_with( + tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task + ) diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py index 8bfc97ae63..11e017464a 100644 --- a/api/tests/unit_tests/utils/test_text_processing.py +++ b/api/tests/unit_tests/utils/test_text_processing.py @@ -8,10 +8,13 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols [ ("...Hello, World!", "Hello, World!"), ("。测试中文标点", "测试中文标点"), - ("!@#Test symbols", "Test symbols"), + # Note: ! is not in the removal pattern, only @# are removed, leaving "!Test symbols" + # The pattern intentionally excludes ! as per #11868 fix + ("@#Test symbols", "Test symbols"), ("Hello, World!", "Hello, World!"), ("", ""), (" ", " "), + ("【测试】", "【测试】"), ], ) def test_remove_leading_symbols(input_text, expected_output): diff --git a/api/uv.lock b/api/uv.lock index f691e90837..ca94e0b8c9 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1040,7 +1040,7 @@ wheels = [ [[package]] name = "clickzetta-connector-python" -version = "0.8.106" +version = "0.8.107" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "future" }, @@ -1054,7 +1054,7 @@ dependencies = [ { name = "urllib3" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/23/38/749c708619f402d4d582dfa73fbeb64ade77b1f250a93bd064d2a1aa3776/clickzetta_connector_python-0.8.106-py3-none-any.whl", hash = "sha256:120d6700051d97609dbd6655c002ab3bc260b7c8e67d39dfc7191e749563f7b4", size = 78121, upload-time = "2025-10-29T02:38:15.014Z" }, + { url = "https://files.pythonhosted.org/packages/19/b4/91dfe25592bbcaf7eede05849c77d09d43a2656943585bbcf7ba4cc604bc/clickzetta_connector_python-0.8.107-py3-none-any.whl", hash = "sha256:7f28752bfa0a50e89ed218db0540c02c6bfbfdae3589ac81cf28523d7caa93b0", size = 76864, upload-time = "2025-12-01T07:56:39.177Z" }, ] [[package]] @@ -1337,7 +1337,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.10.1" +version = "1.11.1" source = { virtual = "." } dependencies = [ { name = "apscheduler" }, @@ -1348,7 +1348,7 @@ dependencies = [ { name = "bs4" }, { name = "cachetools" }, { name = "celery" }, - { name = "chardet" }, + { name = "charset-normalizer" }, { name = "croniter" }, { name = "flask" }, { name = "flask-compress" }, @@ -1371,6 +1371,7 @@ dependencies = [ { name = "httpx-sse" }, { name = "jieba" }, { name = "json-repair" }, + { name = "jsonschema" }, { name = "langfuse" }, { name = "langsmith" }, { name = "litellm" }, @@ -1543,7 +1544,7 @@ requires-dist = [ { name = "bs4", specifier = "~=0.0.1" }, { name = "cachetools", specifier = "~=5.3.0" }, { name = "celery", specifier = "~=5.5.2" }, - { name = "chardet", specifier = "~=5.1.0" }, + { name = "charset-normalizer", specifier = ">=3.4.4" }, { name = "croniter", specifier = ">=6.0.0" }, { name = "flask", specifier = "~=3.1.2" }, { name = "flask-compress", specifier = ">=1.17,<1.18" }, @@ -1566,6 +1567,7 @@ requires-dist = [ { name = "httpx-sse", specifier = "~=0.4.0" }, { name = "jieba", specifier = "==0.42.1" }, { name = "json-repair", specifier = ">=0.41.1" }, + { name = "jsonschema", specifier = ">=4.25.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.1.77" }, { name = "litellm", specifier = "==1.77.1" }, @@ -1655,7 +1657,7 @@ dev = [ { name = "types-docutils", specifier = "~=0.21.0" }, { name = "types-flask-cors", specifier = "~=5.0.0" }, { name = "types-flask-migrate", specifier = "~=4.1.0" }, - { name = "types-gevent", specifier = "~=24.11.0" }, + { name = "types-gevent", specifier = "~=25.9.0" }, { name = "types-greenlet", specifier = "~=3.1.0" }, { name = "types-html5lib", specifier = "~=1.1.11" }, { name = "types-jmespath", specifier = ">=1.0.2.20240106" }, @@ -1679,7 +1681,7 @@ dev = [ { name = "types-redis", specifier = ">=4.6.0.20241004" }, { name = "types-regex", specifier = "~=2024.11.6" }, { name = "types-setuptools", specifier = ">=80.9.0" }, - { name = "types-shapely", specifier = "~=2.0.0" }, + { name = "types-shapely", specifier = "~=2.1.0" }, { name = "types-simplejson", specifier = ">=3.20.0" }, { name = "types-six", specifier = ">=1.17.0" }, { name = "types-tensorflow", specifier = ">=2.18.0" }, @@ -3083,12 +3085,13 @@ wheels = [ [[package]] name = "kubernetes" -version = "34.1.0" +version = "33.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "durationpy" }, { name = "google-auth" }, + { name = "oauthlib" }, { name = "python-dateutil" }, { name = "pyyaml" }, { name = "requests" }, @@ -3097,9 +3100,9 @@ dependencies = [ { name = "urllib3" }, { name = "websocket-client" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ef/55/3f880ef65f559cbed44a9aa20d3bdbc219a2c3a3bac4a30a513029b03ee9/kubernetes-34.1.0.tar.gz", hash = "sha256:8fe8edb0b5d290a2f3ac06596b23f87c658977d46b5f8df9d0f4ea83d0003912", size = 1083771, upload-time = "2025-09-29T20:23:49.283Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/52/19ebe8004c243fdfa78268a96727c71e08f00ff6fe69a301d0b7fcbce3c2/kubernetes-33.1.0.tar.gz", hash = "sha256:f64d829843a54c251061a8e7a14523b521f2dc5c896cf6d65ccf348648a88993", size = 1036779, upload-time = "2025-06-09T21:57:58.521Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ca/ec/65f7d563aa4a62dd58777e8f6aa882f15db53b14eb29aba0c28a20f7eb26/kubernetes-34.1.0-py2.py3-none-any.whl", hash = "sha256:bffba2272534e224e6a7a74d582deb0b545b7c9879d2cd9e4aae9481d1f2cc2a", size = 2008380, upload-time = "2025-09-29T20:23:47.684Z" }, + { url = "https://files.pythonhosted.org/packages/89/43/d9bebfc3db7dea6ec80df5cb2aad8d274dd18ec2edd6c4f21f32c237cbbb/kubernetes-33.1.0-py2.py3-none-any.whl", hash = "sha256:544de42b24b64287f7e0aa9513c93cb503f7f40eea39b20f66810011a86eabc5", size = 1941335, upload-time = "2025-06-09T21:57:56.327Z" }, ] [[package]] @@ -4672,27 +4675,27 @@ wheels = [ [[package]] name = "pyarrow" -version = "14.0.2" +version = "17.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d7/8b/d18b7eb6fb22e5ed6ffcbc073c85dae635778dbd1270a6cf5d750b031e84/pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025", size = 1063645, upload-time = "2023-12-18T15:43:41.625Z" } +sdist = { url = "https://files.pythonhosted.org/packages/27/4e/ea6d43f324169f8aec0e57569443a38bab4b398d09769ca64f7b4d467de3/pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28", size = 1112479, upload-time = "2024-07-17T10:41:25.092Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/94/8a/411ef0b05483076b7f548c74ccaa0f90c1e60d3875db71a821f6ffa8cf42/pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b", size = 26904455, upload-time = "2023-12-18T15:40:43.477Z" }, - { url = "https://files.pythonhosted.org/packages/6c/6c/882a57798877e3a49ba54d8e0540bea24aed78fb42e1d860f08c3449c75e/pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23", size = 23997116, upload-time = "2023-12-18T15:40:48.533Z" }, - { url = "https://files.pythonhosted.org/packages/ec/3f/ef47fe6192ce4d82803a073db449b5292135406c364a7fc49dfbcd34c987/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200", size = 35944575, upload-time = "2023-12-18T15:40:55.128Z" }, - { url = "https://files.pythonhosted.org/packages/1a/90/2021e529d7f234a3909f419d4341d53382541ef77d957fa274a99c533b18/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696", size = 38079719, upload-time = "2023-12-18T15:41:02.565Z" }, - { url = "https://files.pythonhosted.org/packages/30/a9/474caf5fd54a6d5315aaf9284c6e8f5d071ca825325ad64c53137b646e1f/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a", size = 35429706, upload-time = "2023-12-18T15:41:09.955Z" }, - { url = "https://files.pythonhosted.org/packages/d9/f8/cfba56f5353e51c19b0c240380ce39483f4c76e5c4aee5a000f3d75b72da/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02", size = 38001476, upload-time = "2023-12-18T15:41:16.372Z" }, - { url = "https://files.pythonhosted.org/packages/43/3f/7bdf7dc3b3b0cfdcc60760e7880954ba99ccd0bc1e0df806f3dd61bc01cd/pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b", size = 24576230, upload-time = "2023-12-18T15:41:22.561Z" }, - { url = "https://files.pythonhosted.org/packages/69/5b/d8ab6c20c43b598228710e4e4a6cba03a01f6faa3d08afff9ce76fd0fd47/pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944", size = 26819585, upload-time = "2023-12-18T15:41:27.59Z" }, - { url = "https://files.pythonhosted.org/packages/2d/29/bed2643d0dd5e9570405244a61f6db66c7f4704a6e9ce313f84fa5a3675a/pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5", size = 23965222, upload-time = "2023-12-18T15:41:32.449Z" }, - { url = "https://files.pythonhosted.org/packages/2a/34/da464632e59a8cdd083370d69e6c14eae30221acb284f671c6bc9273fadd/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422", size = 35942036, upload-time = "2023-12-18T15:41:38.767Z" }, - { url = "https://files.pythonhosted.org/packages/a8/ff/cbed4836d543b29f00d2355af67575c934999ff1d43e3f438ab0b1b394f1/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07", size = 38089266, upload-time = "2023-12-18T15:41:47.617Z" }, - { url = "https://files.pythonhosted.org/packages/38/41/345011cb831d3dbb2dab762fc244c745a5df94b199223a99af52a5f7dff6/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591", size = 35404468, upload-time = "2023-12-18T15:41:54.49Z" }, - { url = "https://files.pythonhosted.org/packages/fd/af/2fc23ca2068ff02068d8dabf0fb85b6185df40ec825973470e613dbd8790/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379", size = 38003134, upload-time = "2023-12-18T15:42:01.593Z" }, - { url = "https://files.pythonhosted.org/packages/95/1f/9d912f66a87e3864f694e000977a6a70a644ea560289eac1d733983f215d/pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d", size = 25043754, upload-time = "2023-12-18T15:42:07.108Z" }, + { url = "https://files.pythonhosted.org/packages/f9/46/ce89f87c2936f5bb9d879473b9663ce7a4b1f4359acc2f0eb39865eaa1af/pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977", size = 29028748, upload-time = "2024-07-16T10:30:02.609Z" }, + { url = "https://files.pythonhosted.org/packages/8d/8e/ce2e9b2146de422f6638333c01903140e9ada244a2a477918a368306c64c/pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3", size = 27190965, upload-time = "2024-07-16T10:30:10.718Z" }, + { url = "https://files.pythonhosted.org/packages/3b/c8/5675719570eb1acd809481c6d64e2136ffb340bc387f4ca62dce79516cea/pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15", size = 39269081, upload-time = "2024-07-16T10:30:18.878Z" }, + { url = "https://files.pythonhosted.org/packages/5e/78/3931194f16ab681ebb87ad252e7b8d2c8b23dad49706cadc865dff4a1dd3/pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597", size = 39864921, upload-time = "2024-07-16T10:30:27.008Z" }, + { url = "https://files.pythonhosted.org/packages/d8/81/69b6606093363f55a2a574c018901c40952d4e902e670656d18213c71ad7/pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420", size = 38740798, upload-time = "2024-07-16T10:30:34.814Z" }, + { url = "https://files.pythonhosted.org/packages/4c/21/9ca93b84b92ef927814cb7ba37f0774a484c849d58f0b692b16af8eebcfb/pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4", size = 39871877, upload-time = "2024-07-16T10:30:42.672Z" }, + { url = "https://files.pythonhosted.org/packages/30/d1/63a7c248432c71c7d3ee803e706590a0b81ce1a8d2b2ae49677774b813bb/pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03", size = 25151089, upload-time = "2024-07-16T10:30:49.279Z" }, + { url = "https://files.pythonhosted.org/packages/d4/62/ce6ac1275a432b4a27c55fe96c58147f111d8ba1ad800a112d31859fae2f/pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22", size = 29019418, upload-time = "2024-07-16T10:30:55.573Z" }, + { url = "https://files.pythonhosted.org/packages/8e/0a/dbd0c134e7a0c30bea439675cc120012337202e5fac7163ba839aa3691d2/pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053", size = 27152197, upload-time = "2024-07-16T10:31:02.036Z" }, + { url = "https://files.pythonhosted.org/packages/cb/05/3f4a16498349db79090767620d6dc23c1ec0c658a668d61d76b87706c65d/pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a", size = 39263026, upload-time = "2024-07-16T10:31:10.351Z" }, + { url = "https://files.pythonhosted.org/packages/c2/0c/ea2107236740be8fa0e0d4a293a095c9f43546a2465bb7df34eee9126b09/pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc", size = 39880798, upload-time = "2024-07-16T10:31:17.66Z" }, + { url = "https://files.pythonhosted.org/packages/f6/b0/b9164a8bc495083c10c281cc65064553ec87b7537d6f742a89d5953a2a3e/pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a", size = 38715172, upload-time = "2024-07-16T10:31:25.965Z" }, + { url = "https://files.pythonhosted.org/packages/f1/c4/9625418a1413005e486c006e56675334929fad864347c5ae7c1b2e7fe639/pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b", size = 39874508, upload-time = "2024-07-16T10:31:33.721Z" }, + { url = "https://files.pythonhosted.org/packages/ae/49/baafe2a964f663413be3bd1cf5c45ed98c5e42e804e2328e18f4570027c1/pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7", size = 25099235, upload-time = "2024-07-16T10:31:40.893Z" }, ] [[package]] @@ -6287,15 +6290,15 @@ wheels = [ [[package]] name = "types-gevent" -version = "24.11.0.20250401" +version = "25.9.0.20251102" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-greenlet" }, { name = "types-psutil" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f8/db/bdade74c3ba3a266eafd625377eb7b9b37c9c724c7472192100baf0fe507/types_gevent-24.11.0.20250401.tar.gz", hash = "sha256:1443f796a442062698e67d818fca50aa88067dee4021d457a7c0c6bedd6f46ca", size = 36980, upload-time = "2025-04-01T03:07:30.365Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/21/552d818a475e1a31780fb7ae50308feb64211a05eb403491d1a34df95e5f/types_gevent-25.9.0.20251102.tar.gz", hash = "sha256:76f93513af63f4577bb4178c143676dd6c4780abc305f405a4e8ff8f1fa177f8", size = 38096, upload-time = "2025-11-02T03:07:42.112Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/25/3d/c8b12d048565ef12ae65d71a0e566f36c6e076b158d3f94d87edddbeea6b/types_gevent-24.11.0.20250401-py3-none-any.whl", hash = "sha256:6764faf861ea99250c38179c58076392c44019ac3393029f71b06c4a15e8c1d1", size = 54863, upload-time = "2025-04-01T03:07:29.147Z" }, + { url = "https://files.pythonhosted.org/packages/60/a1/776d2de31a02123f225aaa790641113ae47f738f6e8e3091d3012240a88e/types_gevent-25.9.0.20251102-py3-none-any.whl", hash = "sha256:0f14b9977cb04bf3d94444b5ae6ec5d78ac30f74c4df83483e0facec86f19d8b", size = 55592, upload-time = "2025-11-02T03:07:41.003Z" }, ] [[package]] @@ -6554,14 +6557,14 @@ wheels = [ [[package]] name = "types-shapely" -version = "2.0.0.20250404" +version = "2.1.0.20250917" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4e/55/c71a25fd3fc9200df4d0b5fd2f6d74712a82f9a8bbdd90cefb9e6aee39dd/types_shapely-2.0.0.20250404.tar.gz", hash = "sha256:863f540b47fa626c33ae64eae06df171f9ab0347025d4458d2df496537296b4f", size = 25066, upload-time = "2025-04-04T02:54:30.592Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/19/7f28b10994433d43b9caa66f3b9bd6a0a9192b7ce8b5a7fc41534e54b821/types_shapely-2.1.0.20250917.tar.gz", hash = "sha256:5c56670742105aebe40c16414390d35fcaa55d6f774d328c1a18273ab0e2134a", size = 26363, upload-time = "2025-09-17T02:47:44.604Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ce/ff/7f4d414eb81534ba2476f3d54f06f1463c2ebf5d663fd10cff16ba607dd6/types_shapely-2.0.0.20250404-py3-none-any.whl", hash = "sha256:170fb92f5c168a120db39b3287697fdec5c93ef3e1ad15e52552c36b25318821", size = 36350, upload-time = "2025-04-04T02:54:29.506Z" }, + { url = "https://files.pythonhosted.org/packages/e5/a9/554ac40810e530263b6163b30a2b623bc16aae3fb64416f5d2b3657d0729/types_shapely-2.1.0.20250917-py3-none-any.whl", hash = "sha256:9334a79339504d39b040426be4938d422cec419168414dc74972aa746a8bf3a1", size = 37813, upload-time = "2025-09-17T02:47:43.788Z" }, ] [[package]] @@ -6799,11 +6802,11 @@ wheels = [ [[package]] name = "urllib3" -version = "2.3.0" +version = "2.5.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/aa/63/e53da845320b757bf29ef6a9062f5c669fe997973f966045cb019c3f4b66/urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d", size = 307268, upload-time = "2024-12-22T07:47:30.032Z" } +sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369, upload-time = "2024-12-22T07:47:28.074Z" }, + { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, ] [[package]] @@ -7075,14 +7078,14 @@ wheels = [ [[package]] name = "werkzeug" -version = "3.1.3" +version = "3.1.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markupsafe" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9f/69/83029f1f6300c5fb2471d621ab06f6ec6b3324685a2ce0f9777fd4a8b71e/werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746", size = 806925, upload-time = "2024-11-08T15:52:18.093Z" } +sdist = { url = "https://files.pythonhosted.org/packages/45/ea/b0f8eeb287f8df9066e56e831c7824ac6bab645dd6c7a8f4b2d767944f9b/werkzeug-3.1.4.tar.gz", hash = "sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e", size = 864687, upload-time = "2025-11-29T02:15:22.841Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/52/24/ab44c871b0f07f491e5d2ad12c9bd7358e527510618cb1b803a88e986db1/werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e", size = 224498, upload-time = "2024-11-08T15:52:16.132Z" }, + { url = "https://files.pythonhosted.org/packages/2f/f9/9e082990c2585c744734f85bec79b5dae5df9c974ffee58fe421652c8e91/werkzeug-3.1.4-py3-none-any.whl", hash = "sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905", size = 224960, upload-time = "2025-11-29T02:15:21.13Z" }, ] [[package]] diff --git a/docker/.env.example b/docker/.env.example index c9981baaba..04088b72a8 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -233,7 +233,7 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false # Database type, supported values are `postgresql` and `mysql` DB_TYPE=postgresql - +# For MySQL, only `root` user is supported for now DB_USERNAME=postgres DB_PASSWORD=difyai123456 DB_HOST=db_postgres @@ -808,6 +808,19 @@ UPLOAD_FILE_BATCH_LIMIT=5 # Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll UPLOAD_FILE_EXTENSION_BLACKLIST= +# Maximum number of files allowed in a single chunk attachment, default 10. +SINGLE_CHUNK_ATTACHMENT_LIMIT=10 + +# Maximum number of files allowed in a image batch upload operation +IMAGE_FILE_BATCH_LIMIT=10 + +# Maximum allowed image file size for attachments in megabytes, default 2. +ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2 + +# Timeout for downloading image attachments in seconds, default 60. +ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60 + + # ETL type, support: `dify`, `Unstructured` # `dify` Dify's proprietary file extraction scheme # `Unstructured` Unstructured.io file extraction scheme @@ -1076,24 +1089,10 @@ MAX_TREE_DEPTH=50 # ------------------------------ # Environment Variables for database Service # ------------------------------ - -# The name of the default postgres user. -POSTGRES_USER=${DB_USERNAME} -# The password for the default postgres user. -POSTGRES_PASSWORD=${DB_PASSWORD} -# The name of the default postgres database. -POSTGRES_DB=${DB_DATABASE} # Postgres data directory PGDATA=/var/lib/postgresql/data/pgdata # MySQL Default Configuration -# The name of the default mysql user. -MYSQL_USERNAME=${DB_USERNAME} -# The password for the default mysql user. -MYSQL_PASSWORD=${DB_PASSWORD} -# The name of the default mysql database. -MYSQL_DATABASE=${DB_DATABASE} -# MySQL data directory MYSQL_HOST_VOLUME=./volumes/mysql/data # ------------------------------ @@ -1129,6 +1128,10 @@ WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih WEAVIATE_AUTHENTICATION_APIKEY_USERS=hello@dify.ai WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED=true WEAVIATE_AUTHORIZATION_ADMINLIST_USERS=hello@dify.ai +WEAVIATE_DISABLE_TELEMETRY=false +WEAVIATE_ENABLE_TOKENIZER_GSE=false +WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA=false +WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR=false # ------------------------------ # Environment Variables for Chroma @@ -1428,4 +1431,7 @@ WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100 WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0 # Tenant isolated task queue configuration -TENANT_ISOLATED_TASK_CONCURRENCY=1 \ No newline at end of file +TENANT_ISOLATED_TASK_CONCURRENCY=1 + +# The API key of amplitude +AMPLITUDE_API_KEY= diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 703a60ef67..b12d06ca97 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -1,8 +1,27 @@ x-shared-env: &shared-api-worker-env services: + # Init container to fix permissions + init_permissions: + image: busybox:latest + command: + - sh + - -c + - | + FLAG_FILE="/app/api/storage/.init_permissions" + if [ -f "$${FLAG_FILE}" ]; then + echo "Permissions already initialized. Exiting." + exit 0 + fi + echo "Initializing permissions for /app/api/storage" + chown -R 1001:1001 /app/api/storage && touch "$${FLAG_FILE}" + echo "Permissions initialized. Exiting." + volumes: + - ./volumes/app/storage:/app/api/storage + restart: "no" + # API service api: - image: langgenius/dify-api:1.10.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -17,6 +36,8 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -41,7 +62,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.10.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -54,6 +75,8 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -78,7 +101,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.10.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -86,6 +109,8 @@ services: # Startup mode, 'worker_beat' starts the Celery beat for scheduling periodic tasks. MODE: beat depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -106,11 +131,12 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.10.1 + image: langgenius/dify-web:1.11.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} APP_API_URL: ${APP_API_URL:-} + AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-} SENTRY_DSN: ${WEB_SENTRY_DSN:-} NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} @@ -139,9 +165,9 @@ services: - postgresql restart: always environment: - POSTGRES_USER: ${POSTGRES_USER:-postgres} - POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} - POSTGRES_DB: ${POSTGRES_DB:-dify} + POSTGRES_USER: ${DB_USERNAME:-postgres} + POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456} + POSTGRES_DB: ${DB_DATABASE:-dify} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} command: > postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}' @@ -161,7 +187,7 @@ services: "-h", "db_postgres", "-U", - "${PGUSER:-postgres}", + "${DB_USERNAME:-postgres}", "-d", "${DB_DATABASE:-dify}", ] @@ -176,8 +202,8 @@ services: - mysql restart: always environment: - MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456} - MYSQL_DATABASE: ${MYSQL_DATABASE:-dify} + MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456} + MYSQL_DATABASE: ${DB_DATABASE:-dify} command: > --max_connections=1000 --innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M} @@ -193,7 +219,7 @@ services: "ping", "-u", "root", - "-p${MYSQL_PASSWORD:-difyai123456}", + "-p${DB_PASSWORD:-difyai123456}", ] interval: 1s timeout: 3s @@ -243,7 +269,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.4.1-local + image: langgenius/dify-plugin-daemon:0.5.1-local restart: always environment: # Use the shared environment variables. @@ -425,6 +451,10 @@ services: AUTHENTICATION_APIKEY_USERS: ${WEAVIATE_AUTHENTICATION_APIKEY_USERS:-hello@dify.ai} AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} + DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false} + ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false} + ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false} + ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false} # OceanBase vector database oceanbase: diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index f1beefc2f2..68ef217bbd 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -9,8 +9,8 @@ services: env_file: - ./middleware.env environment: - POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} - POSTGRES_DB: ${POSTGRES_DB:-dify} + POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456} + POSTGRES_DB: ${DB_DATABASE:-dify} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} command: > postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}' @@ -32,9 +32,9 @@ services: "-h", "db_postgres", "-U", - "${PGUSER:-postgres}", + "${DB_USERNAME:-postgres}", "-d", - "${POSTGRES_DB:-dify}", + "${DB_DATABASE:-dify}", ] interval: 1s timeout: 3s @@ -48,8 +48,8 @@ services: env_file: - ./middleware.env environment: - MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456} - MYSQL_DATABASE: ${MYSQL_DATABASE:-dify} + MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456} + MYSQL_DATABASE: ${DB_DATABASE:-dify} command: > --max_connections=1000 --innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M} @@ -67,7 +67,7 @@ services: "ping", "-u", "root", - "-p${MYSQL_PASSWORD:-difyai123456}", + "-p${DB_PASSWORD:-difyai123456}", ] interval: 1s timeout: 3s @@ -123,7 +123,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.4.1-local + image: langgenius/dify-plugin-daemon:0.5.1-local restart: always env_file: - ./middleware.env @@ -238,6 +238,7 @@ services: AUTHENTICATION_APIKEY_USERS: ${WEAVIATE_AUTHENTICATION_APIKEY_USERS:-hello@dify.ai} AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} + DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false} ports: - "${EXPOSE_WEAVIATE_PORT:-8080}:8080" - "${EXPOSE_WEAVIATE_GRPC_PORT:-50051}:50051" diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index de2e3943fe..825f0650c8 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -364,6 +364,10 @@ x-shared-env: &shared-api-worker-env UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} UPLOAD_FILE_EXTENSION_BLACKLIST: ${UPLOAD_FILE_EXTENSION_BLACKLIST:-} + SINGLE_CHUNK_ATTACHMENT_LIMIT: ${SINGLE_CHUNK_ATTACHMENT_LIMIT:-10} + IMAGE_FILE_BATCH_LIMIT: ${IMAGE_FILE_BATCH_LIMIT:-10} + ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: ${ATTACHMENT_IMAGE_FILE_SIZE_LIMIT:-2} + ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: ${ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT:-60} ETL_TYPE: ${ETL_TYPE:-dify} UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-} UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-} @@ -455,13 +459,7 @@ x-shared-env: &shared-api-worker-env TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50} - POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}} - POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}} - POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} - MYSQL_USERNAME: ${MYSQL_USERNAME:-${DB_USERNAME}} - MYSQL_PASSWORD: ${MYSQL_PASSWORD:-${DB_PASSWORD}} - MYSQL_DATABASE: ${MYSQL_DATABASE:-${DB_DATABASE}} MYSQL_HOST_VOLUME: ${MYSQL_HOST_VOLUME:-./volumes/mysql/data} SANDBOX_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox} SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release} @@ -480,6 +478,10 @@ x-shared-env: &shared-api-worker-env WEAVIATE_AUTHENTICATION_APIKEY_USERS: ${WEAVIATE_AUTHENTICATION_APIKEY_USERS:-hello@dify.ai} WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} WEAVIATE_AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} + WEAVIATE_DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false} + WEAVIATE_ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false} + WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false} + WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false} CHROMA_SERVER_AUTHN_CREDENTIALS: ${CHROMA_SERVER_AUTHN_CREDENTIALS:-difyai123456} CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} CHROMA_IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} @@ -633,11 +635,31 @@ x-shared-env: &shared-api-worker-env WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: ${WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE:-100} WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: ${WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK:-0} TENANT_ISOLATED_TASK_CONCURRENCY: ${TENANT_ISOLATED_TASK_CONCURRENCY:-1} + AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} services: + # Init container to fix permissions + init_permissions: + image: busybox:latest + command: + - sh + - -c + - | + FLAG_FILE="/app/api/storage/.init_permissions" + if [ -f "$${FLAG_FILE}" ]; then + echo "Permissions already initialized. Exiting." + exit 0 + fi + echo "Initializing permissions for /app/api/storage" + chown -R 1001:1001 /app/api/storage && touch "$${FLAG_FILE}" + echo "Permissions initialized. Exiting." + volumes: + - ./volumes/app/storage:/app/api/storage + restart: "no" + # API service api: - image: langgenius/dify-api:1.10.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -652,6 +674,8 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -676,7 +700,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.10.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -689,6 +713,8 @@ services: PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1} depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -713,7 +739,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.10.1 + image: langgenius/dify-api:1.11.1 restart: always environment: # Use the shared environment variables. @@ -721,6 +747,8 @@ services: # Startup mode, 'worker_beat' starts the Celery beat for scheduling periodic tasks. MODE: beat depends_on: + init_permissions: + condition: service_completed_successfully db_postgres: condition: service_healthy required: false @@ -741,11 +769,12 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.10.1 + image: langgenius/dify-web:1.11.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} APP_API_URL: ${APP_API_URL:-} + AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-} SENTRY_DSN: ${WEB_SENTRY_DSN:-} NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} @@ -774,9 +803,9 @@ services: - postgresql restart: always environment: - POSTGRES_USER: ${POSTGRES_USER:-postgres} - POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} - POSTGRES_DB: ${POSTGRES_DB:-dify} + POSTGRES_USER: ${DB_USERNAME:-postgres} + POSTGRES_PASSWORD: ${DB_PASSWORD:-difyai123456} + POSTGRES_DB: ${DB_DATABASE:-dify} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} command: > postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}' @@ -796,7 +825,7 @@ services: "-h", "db_postgres", "-U", - "${PGUSER:-postgres}", + "${DB_USERNAME:-postgres}", "-d", "${DB_DATABASE:-dify}", ] @@ -811,8 +840,8 @@ services: - mysql restart: always environment: - MYSQL_ROOT_PASSWORD: ${MYSQL_PASSWORD:-difyai123456} - MYSQL_DATABASE: ${MYSQL_DATABASE:-dify} + MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-difyai123456} + MYSQL_DATABASE: ${DB_DATABASE:-dify} command: > --max_connections=1000 --innodb_buffer_pool_size=${MYSQL_INNODB_BUFFER_POOL_SIZE:-512M} @@ -828,7 +857,7 @@ services: "ping", "-u", "root", - "-p${MYSQL_PASSWORD:-difyai123456}", + "-p${DB_PASSWORD:-difyai123456}", ] interval: 1s timeout: 3s @@ -878,7 +907,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.4.1-local + image: langgenius/dify-plugin-daemon:0.5.1-local restart: always environment: # Use the shared environment variables. @@ -1060,6 +1089,10 @@ services: AUTHENTICATION_APIKEY_USERS: ${WEAVIATE_AUTHENTICATION_APIKEY_USERS:-hello@dify.ai} AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} + DISABLE_TELEMETRY: ${WEAVIATE_DISABLE_TELEMETRY:-false} + ENABLE_TOKENIZER_GSE: ${WEAVIATE_ENABLE_TOKENIZER_GSE:-false} + ENABLE_TOKENIZER_KAGOME_JA: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_JA:-false} + ENABLE_TOKENIZER_KAGOME_KR: ${WEAVIATE_ENABLE_TOKENIZER_KAGOME_KR:-false} # OceanBase vector database oceanbase: diff --git a/docker/middleware.env.example b/docker/middleware.env.example index dbfb75a8d6..d4cbcd1762 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -4,6 +4,7 @@ # Database Configuration # Database type, supported values are `postgresql` and `mysql` DB_TYPE=postgresql +# For MySQL, only `root` user is supported for now DB_USERNAME=postgres DB_PASSWORD=difyai123456 DB_HOST=db_postgres @@ -11,11 +12,6 @@ DB_PORT=5432 DB_DATABASE=dify # PostgreSQL Configuration -POSTGRES_USER=${DB_USERNAME} -# The password for the default postgres user. -POSTGRES_PASSWORD=${DB_PASSWORD} -# The name of the default postgres database. -POSTGRES_DB=${DB_DATABASE} # postgres data directory PGDATA=/var/lib/postgresql/data/pgdata PGDATA_HOST_VOLUME=./volumes/db/data @@ -65,11 +61,6 @@ POSTGRES_STATEMENT_TIMEOUT=0 POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0 # MySQL Configuration -MYSQL_USERNAME=${DB_USERNAME} -# MySQL password -MYSQL_PASSWORD=${DB_PASSWORD} -# MySQL database name -MYSQL_DATABASE=${DB_DATABASE} # MySQL data directory host volume MYSQL_HOST_VOLUME=./volumes/mysql/data @@ -132,6 +123,7 @@ WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih WEAVIATE_AUTHENTICATION_APIKEY_USERS=hello@dify.ai WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED=true WEAVIATE_AUTHORIZATION_ADMINLIST_USERS=hello@dify.ai +WEAVIATE_DISABLE_TELEMETRY=false WEAVIATE_HOST_VOLUME=./volumes/weaviate # ------------------------------ diff --git a/docs/suggested-questions-configuration.md b/docs/suggested-questions-configuration.md new file mode 100644 index 0000000000..c726d3b157 --- /dev/null +++ b/docs/suggested-questions-configuration.md @@ -0,0 +1,253 @@ +# Configurable Suggested Questions After Answer + +This document explains how to configure the "Suggested Questions After Answer" feature in Dify using environment variables. + +## Overview + +The suggested questions feature generates follow-up questions after each AI response to help users continue the conversation. By default, Dify generates 3 short questions (under 20 characters each), but you can customize this behavior to better fit your specific use case. + +## Environment Variables + +### `SUGGESTED_QUESTIONS_PROMPT` + +**Description**: Custom prompt template for generating suggested questions. + +**Default**: + +``` +Please help me predict the three most likely questions that human would ask, and keep each question under 20 characters. +MAKE SURE your output is the SAME language as the Assistant's latest response. +The output must be an array in JSON format following the specified schema: +["question1","question2","question3"] +``` + +**Usage Examples**: + +1. **Technical/Developer Questions (Your Use Case)**: + + ```bash + export SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]' + ``` + +1. **Customer Support**: + + ```bash + export SUGGESTED_QUESTIONS_PROMPT='Generate 3 helpful follow-up questions that guide customers toward solving their own problems. Focus on troubleshooting steps and common issues. Keep questions under 30 characters. JSON format: ["q1","q2","q3"]' + ``` + +1. **Educational Content**: + + ```bash + export SUGGESTED_QUESTIONS_PROMPT='Create 4 thought-provoking questions that help students deeper understand the topic. Focus on concepts, relationships, and applications. Questions should be 25-40 characters. JSON: ["question1","question2","question3","question4"]' + ``` + +1. **Multilingual Support**: + + ```bash + export SUGGESTED_QUESTIONS_PROMPT='Generate exactly 3 follow-up questions in the same language as the conversation. Adapt question length appropriately for the language (Chinese: 10-15 chars, English: 20-30 chars, Arabic: 25-35 chars). Always output valid JSON array.' + ``` + +**Important Notes**: + +- The prompt must request JSON array output format +- Include language matching instructions for multilingual support +- Specify clear character limits or question count requirements +- Focus on your specific domain or use case + +### `SUGGESTED_QUESTIONS_MAX_TOKENS` + +**Description**: Maximum number of tokens for the LLM response. + +**Default**: `256` + +**Usage**: + +```bash +export SUGGESTED_QUESTIONS_MAX_TOKENS=512 # For longer questions or more questions +``` + +**Recommended Values**: + +- `256`: Default, good for 3-4 short questions +- `384`: Medium, good for 4-5 medium-length questions +- `512`: High, good for 5+ longer questions or complex prompts +- `1024`: Maximum, for very complex question generation + +### `SUGGESTED_QUESTIONS_TEMPERATURE` + +**Description**: Temperature parameter for LLM creativity. + +**Default**: `0.0` + +**Usage**: + +```bash +export SUGGESTED_QUESTIONS_TEMPERATURE=0.3 # Balanced creativity +``` + +**Recommended Values**: + +- `0.0-0.2`: Very focused, predictable questions (good for technical support) +- `0.3-0.5`: Balanced creativity and relevance (good for general use) +- `0.6-0.8`: More creative, diverse questions (good for brainstorming) +- `0.9-1.0`: Maximum creativity (good for educational exploration) + +## Configuration Examples + +### Example 1: Developer Documentation Chatbot + +```bash +# .env file +SUGGESTED_QUESTIONS_PROMPT='Generate exactly 5 technical follow-up questions that developers would ask after reading code documentation. Focus on implementation details, edge cases, performance considerations, and best practices. Each question should be 40-60 characters long. Output as JSON array: ["question1","question2","question3","question4","question5"]' +SUGGESTED_QUESTIONS_MAX_TOKENS=512 +SUGGESTED_QUESTIONS_TEMPERATURE=0.3 +``` + +### Example 2: Customer Service Bot + +```bash +# .env file +SUGGESTED_QUESTIONS_PROMPT='Create 3 actionable follow-up questions that help customers resolve their own issues. Focus on common problems, troubleshooting steps, and product features. Keep questions simple and under 25 characters. JSON: ["q1","q2","q3"]' +SUGGESTED_QUESTIONS_MAX_TOKENS=256 +SUGGESTED_QUESTIONS_TEMPERATURE=0.1 +``` + +### Example 3: Educational Tutor + +```bash +# .env file +SUGGESTED_QUESTIONS_PROMPT='Generate 4 thought-provoking questions that help students deepen their understanding of the topic. Focus on relationships between concepts, practical applications, and critical thinking. Questions should be 30-45 characters. Output: ["question1","question2","question3","question4"]' +SUGGESTED_QUESTIONS_MAX_TOKENS=384 +SUGGESTED_QUESTIONS_TEMPERATURE=0.6 +``` + +## Implementation Details + +### How It Works + +1. **Environment Variable Loading**: The system checks for environment variables at startup +1. **Fallback to Defaults**: If no environment variables are set, original behavior is preserved +1. **Prompt Template**: The custom prompt is used as-is, allowing full control over question generation +1. **LLM Parameters**: Custom max_tokens and temperature are passed to the LLM API +1. **JSON Parsing**: The system expects JSON array output and parses it accordingly + +### File Changes + +The implementation modifies these files: + +- `api/core/llm_generator/prompts.py`: Environment variable support +- `api/core/llm_generator/llm_generator.py`: Custom LLM parameters +- `api/.env.example`: Documentation of new variables + +### Backward Compatibility + +- ✅ **Zero Breaking Changes**: Works exactly as before if no environment variables are set +- ✅ **Default Behavior Preserved**: Original prompt and parameters used as fallbacks +- ✅ **No Database Changes**: Pure environment variable configuration +- ✅ **No UI Changes Required**: Configuration happens at deployment level + +## Testing Your Configuration + +### Local Testing + +1. Set environment variables: + + ```bash + export SUGGESTED_QUESTIONS_PROMPT='Your test prompt...' + export SUGGESTED_QUESTIONS_MAX_TOKENS=300 + export SUGGESTED_QUESTIONS_TEMPERATURE=0.4 + ``` + +1. Start Dify API: + + ```bash + cd api + python -m flask run --host 0.0.0.0 --port=5001 --debug + ``` + +1. Test the feature in your chat application and verify the questions match your expectations. + +### Monitoring + +Monitor the following when testing: + +- **Question Quality**: Are questions relevant and helpful? +- **Language Matching**: Do questions match the conversation language? +- **JSON Format**: Is output properly formatted as JSON array? +- **Length Constraints**: Do questions follow your length requirements? +- **Response Time**: Are the custom parameters affecting performance? + +## Troubleshooting + +### Common Issues + +1. **Invalid JSON Output**: + + - **Problem**: LLM doesn't return valid JSON + - **Solution**: Make sure your prompt explicitly requests JSON array format + +1. **Questions Too Long/Short**: + + - **Problem**: Questions don't follow length constraints + - **Solution**: Be more specific about character limits in your prompt + +1. **Too Few/Many Questions**: + + - **Problem**: Wrong number of questions generated + - **Solution**: Clearly specify the exact number in your prompt + +1. **Language Mismatch**: + + - **Problem**: Questions in wrong language + - **Solution**: Include explicit language matching instructions in prompt + +1. **Performance Issues**: + + - **Problem**: Slow response times + - **Solution**: Reduce `SUGGESTED_QUESTIONS_MAX_TOKENS` or simplify prompt + +### Debug Logging + +To debug your configuration, you can temporarily add logging to see the actual prompt and parameters being used: + +```python +import logging +logger = logging.getLogger(__name__) + +# In llm_generator.py +logger.info(f"Suggested questions prompt: {prompt}") +logger.info(f"Max tokens: {SUGGESTED_QUESTIONS_MAX_TOKENS}") +logger.info(f"Temperature: {SUGGESTED_QUESTIONS_TEMPERATURE}") +``` + +## Migration Guide + +### From Default Configuration + +If you're currently using the default configuration and want to customize: + +1. **Assess Your Needs**: Determine what aspects need customization (question count, length, domain focus) +1. **Design Your Prompt**: Write a custom prompt that addresses your specific use case +1. **Choose Parameters**: Select appropriate max_tokens and temperature values +1. **Test Incrementally**: Start with small changes and test thoroughly +1. **Deploy Gradually**: Roll out to production after successful testing + +### Best Practices + +1. **Start Simple**: Begin with minimal changes to the default prompt +1. **Test Thoroughly**: Test with various conversation types and languages +1. **Monitor Performance**: Watch for impact on response times and costs +1. **Get User Feedback**: Collect feedback on question quality and relevance +1. **Iterate**: Refine your configuration based on real-world usage + +## Future Enhancements + +This environment variable approach provides immediate customization while maintaining backward compatibility. Future enhancements could include: + +1. **App-Level Configuration**: Different apps with different suggested question settings +1. **Dynamic Prompts**: Context-aware prompts based on conversation content +1. **Multi-Model Support**: Different models for different types of questions +1. **Analytics Dashboard**: Insights into question effectiveness and usage patterns +1. **A/B Testing**: Built-in testing of different prompt configurations + +For now, the environment variable approach offers a simple, reliable way to customize the suggested questions feature for your specific needs. diff --git a/sdks/python-client/LICENSE b/sdks/python-client/LICENSE deleted file mode 100644 index 873e44b4bc..0000000000 --- a/sdks/python-client/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2023 LangGenius - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/sdks/python-client/MANIFEST.in b/sdks/python-client/MANIFEST.in deleted file mode 100644 index 34b7e8711c..0000000000 --- a/sdks/python-client/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -recursive-include dify_client *.py -include README.md -include LICENSE diff --git a/sdks/python-client/README.md b/sdks/python-client/README.md deleted file mode 100644 index ebfb5f5397..0000000000 --- a/sdks/python-client/README.md +++ /dev/null @@ -1,409 +0,0 @@ -# dify-client - -A Dify App Service-API Client, using for build a webapp by request Service-API - -## Usage - -First, install `dify-client` python sdk package: - -``` -pip install dify-client -``` - -### Synchronous Usage - -Write your code with sdk: - -- completion generate with `blocking` response_mode - -```python -from dify_client import CompletionClient - -api_key = "your_api_key" - -# Initialize CompletionClient -completion_client = CompletionClient(api_key) - -# Create Completion Message using CompletionClient -completion_response = completion_client.create_completion_message(inputs={"query": "What's the weather like today?"}, - response_mode="blocking", user="user_id") -completion_response.raise_for_status() - -result = completion_response.json() - -print(result.get('answer')) -``` - -- completion using vision model, like gpt-4-vision - -```python -from dify_client import CompletionClient - -api_key = "your_api_key" - -# Initialize CompletionClient -completion_client = CompletionClient(api_key) - -files = [{ - "type": "image", - "transfer_method": "remote_url", - "url": "your_image_url" -}] - -# files = [{ -# "type": "image", -# "transfer_method": "local_file", -# "upload_file_id": "your_file_id" -# }] - -# Create Completion Message using CompletionClient -completion_response = completion_client.create_completion_message(inputs={"query": "Describe the picture."}, - response_mode="blocking", user="user_id", files=files) -completion_response.raise_for_status() - -result = completion_response.json() - -print(result.get('answer')) -``` - -- chat generate with `streaming` response_mode - -```python -import json -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize ChatClient -chat_client = ChatClient(api_key) - -# Create Chat Message using ChatClient -chat_response = chat_client.create_chat_message(inputs={}, query="Hello", user="user_id", response_mode="streaming") -chat_response.raise_for_status() - -for line in chat_response.iter_lines(decode_unicode=True): - line = line.split('data:', 1)[-1] - if line.strip(): - line = json.loads(line.strip()) - print(line.get('answer')) -``` - -- chat using vision model, like gpt-4-vision - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize ChatClient -chat_client = ChatClient(api_key) - -files = [{ - "type": "image", - "transfer_method": "remote_url", - "url": "your_image_url" -}] - -# files = [{ -# "type": "image", -# "transfer_method": "local_file", -# "upload_file_id": "your_file_id" -# }] - -# Create Chat Message using ChatClient -chat_response = chat_client.create_chat_message(inputs={}, query="Describe the picture.", user="user_id", - response_mode="blocking", files=files) -chat_response.raise_for_status() - -result = chat_response.json() - -print(result.get("answer")) -``` - -- upload file when using vision model - -```python -from dify_client import DifyClient - -api_key = "your_api_key" - -# Initialize Client -dify_client = DifyClient(api_key) - -file_path = "your_image_file_path" -file_name = "panda.jpeg" -mime_type = "image/jpeg" - -with open(file_path, "rb") as file: - files = { - "file": (file_name, file, mime_type) - } - response = dify_client.file_upload("user_id", files) - - result = response.json() - print(f'upload_file_id: {result.get("id")}') -``` - -- Others - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Initialize Client -client = ChatClient(api_key) - -# Get App parameters -parameters = client.get_application_parameters(user="user_id") -parameters.raise_for_status() - -print('[parameters]') -print(parameters.json()) - -# Get Conversation List (only for chat) -conversations = client.get_conversations(user="user_id") -conversations.raise_for_status() - -print('[conversations]') -print(conversations.json()) - -# Get Message List (only for chat) -messages = client.get_conversation_messages(user="user_id", conversation_id="conversation_id") -messages.raise_for_status() - -print('[messages]') -print(messages.json()) - -# Rename Conversation (only for chat) -rename_conversation_response = client.rename_conversation(conversation_id="conversation_id", - name="new_name", user="user_id") -rename_conversation_response.raise_for_status() - -print('[rename result]') -print(rename_conversation_response.json()) -``` - -- Using the Workflow Client - -```python -import json -import requests -from dify_client import WorkflowClient - -api_key = "your_api_key" - -# Initialize Workflow Client -client = WorkflowClient(api_key) - -# Prepare parameters for Workflow Client -user_id = "your_user_id" -context = "previous user interaction / metadata" -user_prompt = "What is the capital of France?" - -inputs = { - "context": context, - "user_prompt": user_prompt, - # Add other input fields expected by your workflow (e.g., additional context, task parameters) - -} - -# Set response mode (default: streaming) -response_mode = "blocking" - -# Run the workflow -response = client.run(inputs=inputs, response_mode=response_mode, user=user_id) -response.raise_for_status() - -# Parse result -result = json.loads(response.text) - -answer = result.get("data").get("outputs") - -print(answer["answer"]) - -``` - -- Dataset Management - -```python -from dify_client import KnowledgeBaseClient - -api_key = "your_api_key" -dataset_id = "your_dataset_id" - -# Use context manager to ensure proper resource cleanup -with KnowledgeBaseClient(api_key, dataset_id) as kb_client: - # Get dataset information - dataset_info = kb_client.get_dataset() - dataset_info.raise_for_status() - print(dataset_info.json()) - - # Update dataset configuration - update_response = kb_client.update_dataset( - name="Updated Dataset Name", - description="Updated description", - indexing_technique="high_quality" - ) - update_response.raise_for_status() - print(update_response.json()) - - # Batch update document status - batch_response = kb_client.batch_update_document_status( - action="enable", - document_ids=["doc_id_1", "doc_id_2", "doc_id_3"] - ) - batch_response.raise_for_status() - print(batch_response.json()) -``` - -- Conversation Variables Management - -```python -from dify_client import ChatClient - -api_key = "your_api_key" - -# Use context manager to ensure proper resource cleanup -with ChatClient(api_key) as chat_client: - # Get all conversation variables - variables = chat_client.get_conversation_variables( - conversation_id="conversation_id", - user="user_id" - ) - variables.raise_for_status() - print(variables.json()) - - # Update a specific conversation variable - update_var = chat_client.update_conversation_variable( - conversation_id="conversation_id", - variable_id="variable_id", - value="new_value", - user="user_id" - ) - update_var.raise_for_status() - print(update_var.json()) -``` - -### Asynchronous Usage - -The SDK provides full async/await support for all API operations using `httpx.AsyncClient`. All async clients mirror their synchronous counterparts but require `await` for method calls. - -- async chat with `blocking` response_mode - -```python -import asyncio -from dify_client import AsyncChatClient - -api_key = "your_api_key" - -async def main(): - # Use async context manager for proper resource cleanup - async with AsyncChatClient(api_key) as client: - response = await client.create_chat_message( - inputs={}, - query="Hello, how are you?", - user="user_id", - response_mode="blocking" - ) - response.raise_for_status() - result = response.json() - print(result.get('answer')) - -# Run the async function -asyncio.run(main()) -``` - -- async completion with `streaming` response_mode - -```python -import asyncio -import json -from dify_client import AsyncCompletionClient - -api_key = "your_api_key" - -async def main(): - async with AsyncCompletionClient(api_key) as client: - response = await client.create_completion_message( - inputs={"query": "What's the weather?"}, - response_mode="streaming", - user="user_id" - ) - response.raise_for_status() - - # Stream the response - async for line in response.aiter_lines(): - if line.startswith('data:'): - data = line[5:].strip() - if data: - chunk = json.loads(data) - print(chunk.get('answer', ''), end='', flush=True) - -asyncio.run(main()) -``` - -- async workflow execution - -```python -import asyncio -from dify_client import AsyncWorkflowClient - -api_key = "your_api_key" - -async def main(): - async with AsyncWorkflowClient(api_key) as client: - response = await client.run( - inputs={"query": "What is machine learning?"}, - response_mode="blocking", - user="user_id" - ) - response.raise_for_status() - result = response.json() - print(result.get("data").get("outputs")) - -asyncio.run(main()) -``` - -- async dataset management - -```python -import asyncio -from dify_client import AsyncKnowledgeBaseClient - -api_key = "your_api_key" -dataset_id = "your_dataset_id" - -async def main(): - async with AsyncKnowledgeBaseClient(api_key, dataset_id) as kb_client: - # Get dataset information - dataset_info = await kb_client.get_dataset() - dataset_info.raise_for_status() - print(dataset_info.json()) - - # List documents - docs = await kb_client.list_documents(page=1, page_size=10) - docs.raise_for_status() - print(docs.json()) - -asyncio.run(main()) -``` - -**Benefits of Async Usage:** - -- **Better Performance**: Handle multiple concurrent API requests efficiently -- **Non-blocking I/O**: Don't block the event loop during network operations -- **Scalability**: Ideal for applications handling many simultaneous requests -- **Modern Python**: Leverages Python's native async/await syntax - -**Available Async Clients:** - -- `AsyncDifyClient` - Base async client -- `AsyncChatClient` - Async chat operations -- `AsyncCompletionClient` - Async completion operations -- `AsyncWorkflowClient` - Async workflow operations -- `AsyncKnowledgeBaseClient` - Async dataset/knowledge base operations -- `AsyncWorkspaceClient` - Async workspace operations - -``` -``` diff --git a/sdks/python-client/build.sh b/sdks/python-client/build.sh deleted file mode 100755 index 525f57c1ef..0000000000 --- a/sdks/python-client/build.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -set -e - -rm -rf build dist *.egg-info - -pip install setuptools wheel twine -python setup.py sdist bdist_wheel -twine upload dist/* diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py deleted file mode 100644 index ced093b20a..0000000000 --- a/sdks/python-client/dify_client/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from dify_client.client import ( - ChatClient, - CompletionClient, - DifyClient, - KnowledgeBaseClient, - WorkflowClient, - WorkspaceClient, -) - -from dify_client.async_client import ( - AsyncChatClient, - AsyncCompletionClient, - AsyncDifyClient, - AsyncKnowledgeBaseClient, - AsyncWorkflowClient, - AsyncWorkspaceClient, -) - -__all__ = [ - # Synchronous clients - "ChatClient", - "CompletionClient", - "DifyClient", - "KnowledgeBaseClient", - "WorkflowClient", - "WorkspaceClient", - # Asynchronous clients - "AsyncChatClient", - "AsyncCompletionClient", - "AsyncDifyClient", - "AsyncKnowledgeBaseClient", - "AsyncWorkflowClient", - "AsyncWorkspaceClient", -] diff --git a/sdks/python-client/dify_client/async_client.py b/sdks/python-client/dify_client/async_client.py deleted file mode 100644 index 23126cf326..0000000000 --- a/sdks/python-client/dify_client/async_client.py +++ /dev/null @@ -1,2074 +0,0 @@ -"""Asynchronous Dify API client. - -This module provides async/await support for all Dify API operations using httpx.AsyncClient. -All client classes mirror their synchronous counterparts but require `await` for method calls. - -Example: - import asyncio - from dify_client import AsyncChatClient - - async def main(): - async with AsyncChatClient(api_key="your-key") as client: - response = await client.create_chat_message( - inputs={}, - query="Hello", - user="user-123" - ) - print(response.json()) - - asyncio.run(main()) -""" - -import json -import os -from typing import Literal, Dict, List, Any, IO, Optional, Union - -import aiofiles -import httpx - - -class AsyncDifyClient: - """Asynchronous Dify API client. - - This client uses httpx.AsyncClient for efficient async connection pooling. - It's recommended to use this client as a context manager: - - Example: - async with AsyncDifyClient(api_key="your-key") as client: - response = await client.get_app_info() - """ - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - ): - """Initialize the async Dify client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds (default: 60.0) - """ - self.api_key = api_key - self.base_url = base_url - self._client = httpx.AsyncClient( - base_url=base_url, - timeout=httpx.Timeout(timeout, connect=5.0), - ) - - async def __aenter__(self): - """Support async context manager protocol.""" - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Clean up resources when exiting async context.""" - await self.aclose() - - async def aclose(self): - """Close the async HTTP client and release resources.""" - if hasattr(self, "_client"): - await self._client.aclose() - - async def _send_request( - self, - method: str, - endpoint: str, - json: Dict | None = None, - params: Dict | None = None, - stream: bool = False, - **kwargs, - ): - """Send an async HTTP request to the Dify API. - - Args: - method: HTTP method (GET, POST, PUT, PATCH, DELETE) - endpoint: API endpoint path - json: JSON request body - params: Query parameters - stream: Whether to stream the response - **kwargs: Additional arguments to pass to httpx.request - - Returns: - httpx.Response object - """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - - response = await self._client.request( - method, - endpoint, - json=json, - params=params, - headers=headers, - **kwargs, - ) - - return response - - async def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict): - """Send an async HTTP request with file uploads. - - Args: - method: HTTP method (POST, PUT, etc.) - endpoint: API endpoint path - data: Form data - files: Files to upload - - Returns: - httpx.Response object - """ - headers = {"Authorization": f"Bearer {self.api_key}"} - - response = await self._client.request( - method, - endpoint, - data=data, - headers=headers, - files=files, - ) - - return response - - async def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): - """Send feedback for a message.""" - data = {"rating": rating, "user": user} - return await self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - - async def get_application_parameters(self, user: str): - """Get application parameters.""" - params = {"user": user} - return await self._send_request("GET", "/parameters", params=params) - - async def file_upload(self, user: str, files: dict): - """Upload a file.""" - data = {"user": user} - return await self._send_request_with_files("POST", "/files/upload", data=data, files=files) - - async def text_to_audio(self, text: str, user: str, streaming: bool = False): - """Convert text to audio.""" - data = {"text": text, "user": user, "streaming": streaming} - return await self._send_request("POST", "/text-to-audio", json=data) - - async def get_meta(self, user: str): - """Get metadata.""" - params = {"user": user} - return await self._send_request("GET", "/meta", params=params) - - async def get_app_info(self): - """Get basic application information including name, description, tags, and mode.""" - return await self._send_request("GET", "/info") - - async def get_app_site_info(self): - """Get application site information.""" - return await self._send_request("GET", "/site") - - async def get_file_preview(self, file_id: str): - """Get file preview by file ID.""" - return await self._send_request("GET", f"/files/{file_id}/preview") - - # App Configuration APIs - async def get_app_site_config(self, app_id: str): - """Get app site configuration. - - Args: - app_id: ID of the app - - Returns: - App site configuration - """ - url = f"/apps/{app_id}/site/config" - return await self._send_request("GET", url) - - async def update_app_site_config(self, app_id: str, config_data: Dict[str, Any]): - """Update app site configuration. - - Args: - app_id: ID of the app - config_data: Configuration data to update - - Returns: - Updated app site configuration - """ - url = f"/apps/{app_id}/site/config" - return await self._send_request("PUT", url, json=config_data) - - async def get_app_api_tokens(self, app_id: str): - """Get API tokens for an app. - - Args: - app_id: ID of the app - - Returns: - List of API tokens - """ - url = f"/apps/{app_id}/api-tokens" - return await self._send_request("GET", url) - - async def create_app_api_token(self, app_id: str, name: str, description: str | None = None): - """Create a new API token for an app. - - Args: - app_id: ID of the app - name: Name for the API token - description: Description for the API token (optional) - - Returns: - Created API token information - """ - data = {"name": name, "description": description} - url = f"/apps/{app_id}/api-tokens" - return await self._send_request("POST", url, json=data) - - async def delete_app_api_token(self, app_id: str, token_id: str): - """Delete an API token. - - Args: - app_id: ID of the app - token_id: ID of the token to delete - - Returns: - Deletion result - """ - url = f"/apps/{app_id}/api-tokens/{token_id}" - return await self._send_request("DELETE", url) - - -class AsyncCompletionClient(AsyncDifyClient): - """Async client for Completion API operations.""" - - async def create_completion_message( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"], - user: str, - files: Dict | None = None, - ): - """Create a completion message. - - Args: - inputs: Input variables for the completion - response_mode: Response mode ('blocking' or 'streaming') - user: User identifier - files: Optional files to include - - Returns: - httpx.Response object - """ - data = { - "inputs": inputs, - "response_mode": response_mode, - "user": user, - "files": files, - } - return await self._send_request( - "POST", - "/completion-messages", - data, - stream=(response_mode == "streaming"), - ) - - -class AsyncChatClient(AsyncDifyClient): - """Async client for Chat API operations.""" - - async def create_chat_message( - self, - inputs: dict, - query: str, - user: str, - response_mode: Literal["blocking", "streaming"] = "blocking", - conversation_id: str | None = None, - files: Dict | None = None, - ): - """Create a chat message. - - Args: - inputs: Input variables for the chat - query: User query/message - user: User identifier - response_mode: Response mode ('blocking' or 'streaming') - conversation_id: Optional conversation ID for context - files: Optional files to include - - Returns: - httpx.Response object - """ - data = { - "inputs": inputs, - "query": query, - "user": user, - "response_mode": response_mode, - "files": files, - } - if conversation_id: - data["conversation_id"] = conversation_id - - return await self._send_request( - "POST", - "/chat-messages", - data, - stream=(response_mode == "streaming"), - ) - - async def get_suggested(self, message_id: str, user: str): - """Get suggested questions for a message.""" - params = {"user": user} - return await self._send_request("GET", f"/messages/{message_id}/suggested", params=params) - - async def stop_message(self, task_id: str, user: str): - """Stop a running message generation.""" - data = {"user": user} - return await self._send_request("POST", f"/chat-messages/{task_id}/stop", data) - - async def get_conversations( - self, - user: str, - last_id: str | None = None, - limit: int | None = None, - pinned: bool | None = None, - ): - """Get list of conversations.""" - params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} - return await self._send_request("GET", "/conversations", params=params) - - async def get_conversation_messages( - self, - user: str, - conversation_id: str | None = None, - first_id: str | None = None, - limit: int | None = None, - ): - """Get messages from a conversation.""" - params = { - "user": user, - "conversation_id": conversation_id, - "first_id": first_id, - "limit": limit, - } - return await self._send_request("GET", "/messages", params=params) - - async def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str): - """Rename a conversation.""" - data = {"name": name, "auto_generate": auto_generate, "user": user} - return await self._send_request("POST", f"/conversations/{conversation_id}/name", data) - - async def delete_conversation(self, conversation_id: str, user: str): - """Delete a conversation.""" - data = {"user": user} - return await self._send_request("DELETE", f"/conversations/{conversation_id}", data) - - async def audio_to_text(self, audio_file: Union[IO[bytes], tuple], user: str): - """Convert audio to text.""" - data = {"user": user} - files = {"file": audio_file} - return await self._send_request_with_files("POST", "/audio-to-text", data, files) - - # Annotation APIs - async def annotation_reply_action( - self, - action: Literal["enable", "disable"], - score_threshold: float, - embedding_provider_name: str, - embedding_model_name: str, - ): - """Enable or disable annotation reply feature.""" - data = { - "score_threshold": score_threshold, - "embedding_provider_name": embedding_provider_name, - "embedding_model_name": embedding_model_name, - } - return await self._send_request("POST", f"/apps/annotation-reply/{action}", json=data) - - async def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str): - """Get the status of an annotation reply action job.""" - return await self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}") - - async def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for the application.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return await self._send_request("GET", "/apps/annotations", params=params) - - async def create_annotation(self, question: str, answer: str): - """Create a new annotation.""" - data = {"question": question, "answer": answer} - return await self._send_request("POST", "/apps/annotations", json=data) - - async def update_annotation(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation.""" - data = {"question": question, "answer": answer} - return await self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data) - - async def delete_annotation(self, annotation_id: str): - """Delete an annotation.""" - return await self._send_request("DELETE", f"/apps/annotations/{annotation_id}") - - # Enhanced Annotation APIs - async def get_annotation_reply_job_status(self, action: str, job_id: str): - """Get status of an annotation reply action job.""" - url = f"/apps/annotation-reply/{action}/status/{job_id}" - return await self._send_request("GET", url) - - async def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for application with pagination.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - return await self._send_request("GET", "/apps/annotations", params=params) - - async def create_annotation_with_response(self, question: str, answer: str): - """Create a new annotation with full response handling.""" - data = {"question": question, "answer": answer} - return await self._send_request("POST", "/apps/annotations", json=data) - - async def update_annotation_with_response(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation with full response handling.""" - data = {"question": question, "answer": answer} - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("PUT", url, json=data) - - async def delete_annotation_with_response(self, annotation_id: str): - """Delete an annotation with full response handling.""" - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("DELETE", url) - - # Conversation Variables APIs - async def get_conversation_variables(self, conversation_id: str, user: str): - """Get all variables for a specific conversation. - - Args: - conversation_id: The conversation ID to query variables for - user: User identifier - - Returns: - Response from the API containing: - - variables: List of conversation variables with their values - - conversation_id: The conversation ID - """ - params = {"user": user} - url = f"/conversations/{conversation_id}/variables" - return await self._send_request("GET", url, params=params) - - async def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str): - """Update a specific conversation variable. - - Args: - conversation_id: The conversation ID - variable_id: The variable ID to update - value: New value for the variable - user: User identifier - - Returns: - Response from the API with updated variable information - """ - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return await self._send_request("PATCH", url, json=data) - - # Enhanced Conversation Variable APIs - async def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return await self._send_request("GET", url, params=params) - - async def update_conversation_variable_with_response( - self, conversation_id: str, variable_id: str, user: str, value: Any - ): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return await self._send_request("PUT", url, data=data) - - # Additional annotation methods for API parity - async def get_annotation_reply_job_status(self, action: str, job_id: str): - """Get status of an annotation reply action job.""" - url = f"/apps/annotation-reply/{action}/status/{job_id}" - return await self._send_request("GET", url) - - async def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for application with pagination.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - return await self._send_request("GET", "/apps/annotations", params=params) - - async def create_annotation_with_response(self, question: str, answer: str): - """Create a new annotation with full response handling.""" - data = {"question": question, "answer": answer} - return await self._send_request("POST", "/apps/annotations", json=data) - - async def update_annotation_with_response(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation with full response handling.""" - data = {"question": question, "answer": answer} - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("PUT", url, json=data) - - async def delete_annotation_with_response(self, annotation_id: str): - """Delete an annotation with full response handling.""" - url = f"/apps/annotations/{annotation_id}" - return await self._send_request("DELETE", url) - - -class AsyncWorkflowClient(AsyncDifyClient): - """Async client for Workflow API operations.""" - - async def run( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - """Run a workflow.""" - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return await self._send_request("POST", "/workflows/run", data) - - async def stop(self, task_id: str, user: str): - """Stop a running workflow task.""" - data = {"user": user} - return await self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) - - async def get_result(self, workflow_run_id: str): - """Get workflow run result.""" - return await self._send_request("GET", f"/workflows/run/{workflow_run_id}") - - async def get_workflow_logs( - self, - keyword: str = None, - status: Literal["succeeded", "failed", "stopped"] | None = None, - page: int = 1, - limit: int = 20, - created_at__before: str = None, - created_at__after: str = None, - created_by_end_user_session_id: str = None, - created_by_account: str = None, - ): - """Get workflow execution logs with optional filtering.""" - params = { - "page": page, - "limit": limit, - "keyword": keyword, - "status": status, - "created_at__before": created_at__before, - "created_at__after": created_at__after, - "created_by_end_user_session_id": created_by_end_user_session_id, - "created_by_account": created_by_account, - } - return await self._send_request("GET", "/workflows/logs", params=params) - - async def run_specific_workflow( - self, - workflow_id: str, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - """Run a specific workflow by workflow ID.""" - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return await self._send_request( - "POST", - f"/workflows/{workflow_id}/run", - data, - stream=(response_mode == "streaming"), - ) - - # Enhanced Workflow APIs - async def get_workflow_draft(self, app_id: str): - """Get workflow draft configuration. - - Args: - app_id: ID of the workflow app - - Returns: - Workflow draft configuration - """ - url = f"/apps/{app_id}/workflow/draft" - return await self._send_request("GET", url) - - async def update_workflow_draft(self, app_id: str, workflow_data: Dict[str, Any]): - """Update workflow draft configuration. - - Args: - app_id: ID of the workflow app - workflow_data: Workflow configuration data - - Returns: - Updated workflow draft - """ - url = f"/apps/{app_id}/workflow/draft" - return await self._send_request("PUT", url, json=workflow_data) - - async def publish_workflow(self, app_id: str): - """Publish workflow from draft. - - Args: - app_id: ID of the workflow app - - Returns: - Published workflow information - """ - url = f"/apps/{app_id}/workflow/publish" - return await self._send_request("POST", url) - - async def get_workflow_run_history( - self, - app_id: str, - page: int = 1, - limit: int = 20, - status: Literal["succeeded", "failed", "stopped"] | None = None, - ): - """Get workflow run history. - - Args: - app_id: ID of the workflow app - page: Page number (default: 1) - limit: Number of items per page (default: 20) - status: Filter by status (optional) - - Returns: - Paginated workflow run history - """ - params = {"page": page, "limit": limit} - if status: - params["status"] = status - url = f"/apps/{app_id}/workflow/runs" - return await self._send_request("GET", url, params=params) - - -class AsyncWorkspaceClient(AsyncDifyClient): - """Async client for workspace-related operations.""" - - async def get_available_models(self, model_type: str): - """Get available models by model type.""" - url = f"/workspaces/current/models/model-types/{model_type}" - return await self._send_request("GET", url) - - async def get_available_models_by_type(self, model_type: str): - """Get available models by model type (enhanced version).""" - url = f"/workspaces/current/models/model-types/{model_type}" - return await self._send_request("GET", url) - - async def get_model_providers(self): - """Get all model providers.""" - return await self._send_request("GET", "/workspaces/current/model-providers") - - async def get_model_provider_models(self, provider_name: str): - """Get models for a specific provider.""" - url = f"/workspaces/current/model-providers/{provider_name}/models" - return await self._send_request("GET", url) - - async def validate_model_provider_credentials(self, provider_name: str, credentials: Dict[str, Any]): - """Validate model provider credentials.""" - url = f"/workspaces/current/model-providers/{provider_name}/credentials/validate" - return await self._send_request("POST", url, json=credentials) - - # File Management APIs - async def get_file_info(self, file_id: str): - """Get information about a specific file.""" - url = f"/files/{file_id}/info" - return await self._send_request("GET", url) - - async def get_file_download_url(self, file_id: str): - """Get download URL for a file.""" - url = f"/files/{file_id}/download-url" - return await self._send_request("GET", url) - - async def delete_file(self, file_id: str): - """Delete a file.""" - url = f"/files/{file_id}" - return await self._send_request("DELETE", url) - - -class AsyncKnowledgeBaseClient(AsyncDifyClient): - """Async client for Knowledge Base API operations.""" - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - dataset_id: str | None = None, - timeout: float = 60.0, - ): - """Construct an AsyncKnowledgeBaseClient object. - - Args: - api_key: API key of Dify - base_url: Base URL of Dify API - dataset_id: ID of the dataset - timeout: Request timeout in seconds - """ - super().__init__(api_key=api_key, base_url=base_url, timeout=timeout) - self.dataset_id = dataset_id - - def _get_dataset_id(self): - """Get the dataset ID, raise error if not set.""" - if self.dataset_id is None: - raise ValueError("dataset_id is not set") - return self.dataset_id - - async def create_dataset(self, name: str, **kwargs): - """Create a new dataset.""" - return await self._send_request("POST", "/datasets", {"name": name}, **kwargs) - - async def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): - """List all datasets.""" - return await self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs) - - async def create_document_by_text(self, name: str, text: str, extra_params: Dict | None = None, **kwargs): - """Create a document by text. - - Args: - name: Name of the document - text: Text content of the document - extra_params: Extra parameters for the API - - Returns: - Response from the API - """ - data = { - "indexing_technique": "high_quality", - "process_rule": {"mode": "automatic"}, - "name": name, - "text": text, - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/document/create_by_text" - return await self._send_request("POST", url, json=data, **kwargs) - - async def update_document_by_text( - self, - document_id: str, - name: str, - text: str, - extra_params: Dict | None = None, - **kwargs, - ): - """Update a document by text.""" - data = {"name": name, "text": text} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" - return await self._send_request("POST", url, json=data, **kwargs) - - async def create_document_by_file( - self, - file_path: str, - original_document_id: str | None = None, - extra_params: Dict | None = None, - ): - """Create a document by file.""" - async with aiofiles.open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = { - "process_rule": {"mode": "automatic"}, - "indexing_technique": "high_quality", - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - if original_document_id is not None: - data["original_document_id"] = original_document_id - url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" - return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - async def update_document_by_file(self, document_id: str, file_path: str, extra_params: Dict | None = None): - """Update a document by file.""" - async with aiofiles.open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = {} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" - return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - async def batch_indexing_status(self, batch_id: str, **kwargs): - """Get the status of the batch indexing.""" - url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status" - return await self._send_request("GET", url, **kwargs) - - async def delete_dataset(self): - """Delete this dataset.""" - url = f"/datasets/{self._get_dataset_id()}" - return await self._send_request("DELETE", url) - - async def delete_document(self, document_id: str): - """Delete a document.""" - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" - return await self._send_request("DELETE", url) - - async def list_documents( - self, - page: int | None = None, - page_size: int | None = None, - keyword: str | None = None, - **kwargs, - ): - """Get a list of documents in this dataset.""" - params = { - "page": page, - "limit": page_size, - "keyword": keyword, - } - url = f"/datasets/{self._get_dataset_id()}/documents" - return await self._send_request("GET", url, params=params, **kwargs) - - async def add_segments(self, document_id: str, segments: list[dict], **kwargs): - """Add segments to a document.""" - data = {"segments": segments} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - return await self._send_request("POST", url, json=data, **kwargs) - - async def query_segments( - self, - document_id: str, - keyword: str | None = None, - status: str | None = None, - **kwargs, - ): - """Query segments in this document. - - Args: - document_id: ID of the document - keyword: Query keyword (optional) - status: Status of the segment (optional, e.g., 'completed') - **kwargs: Additional parameters to pass to the API. - Can include a 'params' dict for extra query parameters. - - Returns: - Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - params = { - "keyword": keyword, - "status": status, - } - if "params" in kwargs: - params.update(kwargs.pop("params")) - return await self._send_request("GET", url, params=params, **kwargs) - - async def delete_document_segment(self, document_id: str, segment_id: str): - """Delete a segment from a document.""" - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return await self._send_request("DELETE", url) - - async def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs): - """Update a segment in a document.""" - data = {"segment": segment_data} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return await self._send_request("POST", url, json=data, **kwargs) - - # Advanced Knowledge Base APIs - async def hit_testing( - self, - query: str, - retrieval_model: Dict[str, Any] = None, - external_retrieval_model: Dict[str, Any] = None, - ): - """Perform hit testing on the dataset.""" - data = {"query": query} - if retrieval_model: - data["retrieval_model"] = retrieval_model - if external_retrieval_model: - data["external_retrieval_model"] = external_retrieval_model - url = f"/datasets/{self._get_dataset_id()}/hit-testing" - return await self._send_request("POST", url, json=data) - - async def get_dataset_metadata(self): - """Get dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return await self._send_request("GET", url) - - async def create_dataset_metadata(self, metadata_data: Dict[str, Any]): - """Create dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return await self._send_request("POST", url, json=metadata_data) - - async def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]): - """Update dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}" - return await self._send_request("PATCH", url, json=metadata_data) - - async def get_built_in_metadata(self): - """Get built-in metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in" - return await self._send_request("GET", url) - - async def manage_built_in_metadata(self, action: str, metadata_data: Dict[str, Any] = None): - """Manage built-in metadata with specified action.""" - data = metadata_data or {} - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}" - return await self._send_request("POST", url, json=data) - - async def update_documents_metadata(self, operation_data: List[Dict[str, Any]]): - """Update metadata for multiple documents.""" - url = f"/datasets/{self._get_dataset_id()}/documents/metadata" - data = {"operation_data": operation_data} - return await self._send_request("POST", url, json=data) - - # Dataset Tags APIs - async def list_dataset_tags(self): - """List all dataset tags.""" - return await self._send_request("GET", "/datasets/tags") - - async def bind_dataset_tags(self, tag_ids: List[str]): - """Bind tags to dataset.""" - data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()} - return await self._send_request("POST", "/datasets/tags/binding", json=data) - - async def unbind_dataset_tag(self, tag_id: str): - """Unbind a single tag from dataset.""" - data = {"tag_id": tag_id, "target_id": self._get_dataset_id()} - return await self._send_request("POST", "/datasets/tags/unbinding", json=data) - - async def get_dataset_tags(self): - """Get tags for current dataset.""" - url = f"/datasets/{self._get_dataset_id()}/tags" - return await self._send_request("GET", url) - - # RAG Pipeline APIs - async def get_datasource_plugins(self, is_published: bool = True): - """Get datasource plugins for RAG pipeline.""" - params = {"is_published": is_published} - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins" - return await self._send_request("GET", url, params=params) - - async def run_datasource_node( - self, - node_id: str, - inputs: Dict[str, Any], - datasource_type: str, - is_published: bool = True, - credential_id: str = None, - ): - """Run a datasource node in RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "is_published": is_published, - } - if credential_id: - data["credential_id"] = credential_id - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run" - return await self._send_request("POST", url, json=data, stream=True) - - async def run_rag_pipeline( - self, - inputs: Dict[str, Any], - datasource_type: str, - datasource_info_list: List[Dict[str, Any]], - start_node_id: str, - is_published: bool = True, - response_mode: Literal["streaming", "blocking"] = "blocking", - ): - """Run RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "datasource_info_list": datasource_info_list, - "start_node_id": start_node_id, - "is_published": is_published, - "response_mode": response_mode, - } - url = f"/datasets/{self._get_dataset_id()}/pipeline/run" - return await self._send_request("POST", url, json=data, stream=response_mode == "streaming") - - async def upload_pipeline_file(self, file_path: str): - """Upload file for RAG pipeline.""" - async with aiofiles.open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - return await self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files) - - # Dataset Management APIs - async def get_dataset(self, dataset_id: str | None = None): - """Get detailed information about a specific dataset.""" - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - return await self._send_request("GET", url) - - async def update_dataset( - self, - dataset_id: str | None = None, - name: str | None = None, - description: str | None = None, - indexing_technique: str | None = None, - embedding_model: str | None = None, - embedding_model_provider: str | None = None, - retrieval_model: Dict[str, Any] | None = None, - **kwargs, - ): - """Update dataset configuration. - - Args: - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - name: New dataset name - description: New dataset description - indexing_technique: Indexing technique ('high_quality' or 'economy') - embedding_model: Embedding model name - embedding_model_provider: Embedding model provider - retrieval_model: Retrieval model configuration dict - **kwargs: Additional parameters to pass to the API - - Returns: - Response from the API with updated dataset information - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - - payload = { - "name": name, - "description": description, - "indexing_technique": indexing_technique, - "embedding_model": embedding_model, - "embedding_model_provider": embedding_model_provider, - "retrieval_model": retrieval_model, - } - - data = {k: v for k, v in payload.items() if v is not None} - data.update(kwargs) - - return await self._send_request("PATCH", url, json=data) - - async def batch_update_document_status( - self, - action: Literal["enable", "disable", "archive", "un_archive"], - document_ids: List[str], - dataset_id: str | None = None, - ): - """Batch update document status.""" - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}/documents/status/{action}" - data = {"document_ids": document_ids} - return await self._send_request("PATCH", url, json=data) - - # Enhanced Dataset APIs - - async def create_dataset_from_template(self, template_name: str, name: str, description: str | None = None): - """Create a dataset from a predefined template. - - Args: - template_name: Name of the template to use - name: Name for the new dataset - description: Description for the dataset (optional) - - Returns: - Created dataset information - """ - data = { - "template_name": template_name, - "name": name, - "description": description, - } - return await self._send_request("POST", "/datasets/from-template", json=data) - - async def duplicate_dataset(self, dataset_id: str, name: str): - """Duplicate an existing dataset. - - Args: - dataset_id: ID of dataset to duplicate - name: Name for duplicated dataset - - Returns: - New dataset information - """ - data = {"name": name} - url = f"/datasets/{dataset_id}/duplicate" - return await self._send_request("POST", url, json=data) - - async def update_conversation_variable_with_response( - self, conversation_id: str, variable_id: str, user: str, value: Any - ): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return await self._send_request("PUT", url, json=data) - - async def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return await self._send_request("GET", url, params=params) - - -class AsyncEnterpriseClient(AsyncDifyClient): - """Async Enterprise and Account Management APIs for Dify platform administration.""" - - async def get_account_info(self): - """Get current account information.""" - return await self._send_request("GET", "/account") - - async def update_account_info(self, account_data: Dict[str, Any]): - """Update account information.""" - return await self._send_request("PUT", "/account", json=account_data) - - # Member Management APIs - async def list_members(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List workspace members with pagination.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - return await self._send_request("GET", "/members", params=params) - - async def invite_member(self, email: str, role: str, name: str | None = None): - """Invite a new member to the workspace.""" - data = {"email": email, "role": role} - if name: - data["name"] = name - return await self._send_request("POST", "/members/invite", json=data) - - async def get_member(self, member_id: str): - """Get detailed information about a specific member.""" - url = f"/members/{member_id}" - return await self._send_request("GET", url) - - async def update_member(self, member_id: str, member_data: Dict[str, Any]): - """Update member information.""" - url = f"/members/{member_id}" - return await self._send_request("PUT", url, json=member_data) - - async def remove_member(self, member_id: str): - """Remove a member from the workspace.""" - url = f"/members/{member_id}" - return await self._send_request("DELETE", url) - - async def deactivate_member(self, member_id: str): - """Deactivate a member account.""" - url = f"/members/{member_id}/deactivate" - return await self._send_request("POST", url) - - async def reactivate_member(self, member_id: str): - """Reactivate a deactivated member account.""" - url = f"/members/{member_id}/reactivate" - return await self._send_request("POST", url) - - # Role Management APIs - async def list_roles(self): - """List all available roles in the workspace.""" - return await self._send_request("GET", "/roles") - - async def create_role(self, name: str, description: str, permissions: List[str]): - """Create a new role with specified permissions.""" - data = {"name": name, "description": description, "permissions": permissions} - return await self._send_request("POST", "/roles", json=data) - - async def get_role(self, role_id: str): - """Get detailed information about a specific role.""" - url = f"/roles/{role_id}" - return await self._send_request("GET", url) - - async def update_role(self, role_id: str, role_data: Dict[str, Any]): - """Update role information.""" - url = f"/roles/{role_id}" - return await self._send_request("PUT", url, json=role_data) - - async def delete_role(self, role_id: str): - """Delete a role.""" - url = f"/roles/{role_id}" - return await self._send_request("DELETE", url) - - # Permission Management APIs - async def list_permissions(self): - """List all available permissions.""" - return await self._send_request("GET", "/permissions") - - async def get_role_permissions(self, role_id: str): - """Get permissions for a specific role.""" - url = f"/roles/{role_id}/permissions" - return await self._send_request("GET", url) - - async def update_role_permissions(self, role_id: str, permissions: List[str]): - """Update permissions for a role.""" - url = f"/roles/{role_id}/permissions" - data = {"permissions": permissions} - return await self._send_request("PUT", url, json=data) - - # Workspace Settings APIs - async def get_workspace_settings(self): - """Get workspace settings and configuration.""" - return await self._send_request("GET", "/workspace/settings") - - async def update_workspace_settings(self, settings_data: Dict[str, Any]): - """Update workspace settings.""" - return await self._send_request("PUT", "/workspace/settings", json=settings_data) - - async def get_workspace_statistics(self): - """Get workspace usage statistics.""" - return await self._send_request("GET", "/workspace/statistics") - - # Billing and Subscription APIs - async def get_billing_info(self): - """Get current billing information.""" - return await self._send_request("GET", "/billing") - - async def get_subscription_info(self): - """Get current subscription information.""" - return await self._send_request("GET", "/subscription") - - async def update_subscription(self, subscription_data: Dict[str, Any]): - """Update subscription settings.""" - return await self._send_request("PUT", "/subscription", json=subscription_data) - - async def get_billing_history(self, page: int = 1, limit: int = 20): - """Get billing history with pagination.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/billing/history", params=params) - - async def get_usage_metrics(self, start_date: str, end_date: str, metric_type: str | None = None): - """Get usage metrics for a date range.""" - params = {"start_date": start_date, "end_date": end_date} - if metric_type: - params["metric_type"] = metric_type - return await self._send_request("GET", "/usage/metrics", params=params) - - # Audit Logs APIs - async def get_audit_logs( - self, - page: int = 1, - limit: int = 20, - action: str | None = None, - user_id: str | None = None, - start_date: str | None = None, - end_date: str | None = None, - ): - """Get audit logs with filtering options.""" - params = {"page": page, "limit": limit} - if action: - params["action"] = action - if user_id: - params["user_id"] = user_id - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - return await self._send_request("GET", "/audit/logs", params=params) - - async def export_audit_logs(self, format: str = "csv", filters: Dict[str, Any] | None = None): - """Export audit logs in specified format.""" - params = {"format": format} - if filters: - params.update(filters) - return await self._send_request("GET", "/audit/logs/export", params=params) - - -class AsyncSecurityClient(AsyncDifyClient): - """Async Security and Access Control APIs for Dify platform security management.""" - - # API Key Management APIs - async def list_api_keys(self, page: int = 1, limit: int = 20, status: str | None = None): - """List all API keys with pagination and filtering.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - return await self._send_request("GET", "/security/api-keys", params=params) - - async def create_api_key( - self, - name: str, - permissions: List[str], - expires_at: str | None = None, - description: str | None = None, - ): - """Create a new API key with specified permissions.""" - data = {"name": name, "permissions": permissions} - if expires_at: - data["expires_at"] = expires_at - if description: - data["description"] = description - return await self._send_request("POST", "/security/api-keys", json=data) - - async def get_api_key(self, key_id: str): - """Get detailed information about an API key.""" - url = f"/security/api-keys/{key_id}" - return await self._send_request("GET", url) - - async def update_api_key(self, key_id: str, key_data: Dict[str, Any]): - """Update API key information.""" - url = f"/security/api-keys/{key_id}" - return await self._send_request("PUT", url, json=key_data) - - async def revoke_api_key(self, key_id: str): - """Revoke an API key.""" - url = f"/security/api-keys/{key_id}/revoke" - return await self._send_request("POST", url) - - async def rotate_api_key(self, key_id: str): - """Rotate an API key (generate new key).""" - url = f"/security/api-keys/{key_id}/rotate" - return await self._send_request("POST", url) - - # Rate Limiting APIs - async def get_rate_limits(self): - """Get current rate limiting configuration.""" - return await self._send_request("GET", "/security/rate-limits") - - async def update_rate_limits(self, limits_config: Dict[str, Any]): - """Update rate limiting configuration.""" - return await self._send_request("PUT", "/security/rate-limits", json=limits_config) - - async def get_rate_limit_usage(self, timeframe: str = "1h"): - """Get rate limit usage statistics.""" - params = {"timeframe": timeframe} - return await self._send_request("GET", "/security/rate-limits/usage", params=params) - - # Access Control Lists APIs - async def list_access_policies(self, page: int = 1, limit: int = 20): - """List access control policies.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/security/access-policies", params=params) - - async def create_access_policy(self, policy_data: Dict[str, Any]): - """Create a new access control policy.""" - return await self._send_request("POST", "/security/access-policies", json=policy_data) - - async def get_access_policy(self, policy_id: str): - """Get detailed information about an access policy.""" - url = f"/security/access-policies/{policy_id}" - return await self._send_request("GET", url) - - async def update_access_policy(self, policy_id: str, policy_data: Dict[str, Any]): - """Update an access control policy.""" - url = f"/security/access-policies/{policy_id}" - return await self._send_request("PUT", url, json=policy_data) - - async def delete_access_policy(self, policy_id: str): - """Delete an access control policy.""" - url = f"/security/access-policies/{policy_id}" - return await self._send_request("DELETE", url) - - # Security Settings APIs - async def get_security_settings(self): - """Get security configuration settings.""" - return await self._send_request("GET", "/security/settings") - - async def update_security_settings(self, settings_data: Dict[str, Any]): - """Update security configuration settings.""" - return await self._send_request("PUT", "/security/settings", json=settings_data) - - async def get_security_audit_logs( - self, - page: int = 1, - limit: int = 20, - event_type: str | None = None, - start_date: str | None = None, - end_date: str | None = None, - ): - """Get security-specific audit logs.""" - params = {"page": page, "limit": limit} - if event_type: - params["event_type"] = event_type - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - return await self._send_request("GET", "/security/audit-logs", params=params) - - # IP Whitelist/Blacklist APIs - async def get_ip_whitelist(self): - """Get IP whitelist configuration.""" - return await self._send_request("GET", "/security/ip-whitelist") - - async def update_ip_whitelist(self, ip_list: List[str], description: str | None = None): - """Update IP whitelist configuration.""" - data = {"ip_list": ip_list} - if description: - data["description"] = description - return await self._send_request("PUT", "/security/ip-whitelist", json=data) - - async def get_ip_blacklist(self): - """Get IP blacklist configuration.""" - return await self._send_request("GET", "/security/ip-blacklist") - - async def update_ip_blacklist(self, ip_list: List[str], description: str | None = None): - """Update IP blacklist configuration.""" - data = {"ip_list": ip_list} - if description: - data["description"] = description - return await self._send_request("PUT", "/security/ip-blacklist", json=data) - - # Authentication Settings APIs - async def get_auth_settings(self): - """Get authentication configuration settings.""" - return await self._send_request("GET", "/security/auth-settings") - - async def update_auth_settings(self, auth_data: Dict[str, Any]): - """Update authentication configuration settings.""" - return await self._send_request("PUT", "/security/auth-settings", json=auth_data) - - async def test_auth_configuration(self, auth_config: Dict[str, Any]): - """Test authentication configuration.""" - return await self._send_request("POST", "/security/auth-settings/test", json=auth_config) - - -class AsyncAnalyticsClient(AsyncDifyClient): - """Async Analytics and Monitoring APIs for Dify platform insights and metrics.""" - - # Usage Analytics APIs - async def get_usage_analytics( - self, - start_date: str, - end_date: str, - granularity: str = "day", - metrics: List[str] | None = None, - ): - """Get usage analytics for specified date range.""" - params = { - "start_date": start_date, - "end_date": end_date, - "granularity": granularity, - } - if metrics: - params["metrics"] = ",".join(metrics) - return await self._send_request("GET", "/analytics/usage", params=params) - - async def get_app_usage_analytics(self, app_id: str, start_date: str, end_date: str, granularity: str = "day"): - """Get usage analytics for a specific app.""" - params = { - "start_date": start_date, - "end_date": end_date, - "granularity": granularity, - } - url = f"/analytics/apps/{app_id}/usage" - return await self._send_request("GET", url, params=params) - - async def get_user_analytics(self, start_date: str, end_date: str, user_segment: str | None = None): - """Get user analytics and behavior insights.""" - params = {"start_date": start_date, "end_date": end_date} - if user_segment: - params["user_segment"] = user_segment - return await self._send_request("GET", "/analytics/users", params=params) - - # Performance Metrics APIs - async def get_performance_metrics(self, start_date: str, end_date: str, metric_type: str | None = None): - """Get performance metrics for the platform.""" - params = {"start_date": start_date, "end_date": end_date} - if metric_type: - params["metric_type"] = metric_type - return await self._send_request("GET", "/analytics/performance", params=params) - - async def get_app_performance_metrics(self, app_id: str, start_date: str, end_date: str): - """Get performance metrics for a specific app.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/analytics/apps/{app_id}/performance" - return await self._send_request("GET", url, params=params) - - async def get_model_performance_metrics(self, model_provider: str, model_name: str, start_date: str, end_date: str): - """Get performance metrics for a specific model.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/analytics/models/{model_provider}/{model_name}/performance" - return await self._send_request("GET", url, params=params) - - # Cost Tracking APIs - async def get_cost_analytics(self, start_date: str, end_date: str, cost_type: str | None = None): - """Get cost analytics and breakdown.""" - params = {"start_date": start_date, "end_date": end_date} - if cost_type: - params["cost_type"] = cost_type - return await self._send_request("GET", "/analytics/costs", params=params) - - async def get_app_cost_analytics(self, app_id: str, start_date: str, end_date: str): - """Get cost analytics for a specific app.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/analytics/apps/{app_id}/costs" - return await self._send_request("GET", url, params=params) - - async def get_cost_forecast(self, forecast_period: str = "30d"): - """Get cost forecast for specified period.""" - params = {"forecast_period": forecast_period} - return await self._send_request("GET", "/analytics/costs/forecast", params=params) - - # Real-time Monitoring APIs - async def get_real_time_metrics(self): - """Get real-time platform metrics.""" - return await self._send_request("GET", "/analytics/realtime") - - async def get_app_real_time_metrics(self, app_id: str): - """Get real-time metrics for a specific app.""" - url = f"/analytics/apps/{app_id}/realtime" - return await self._send_request("GET", url) - - async def get_system_health(self): - """Get overall system health status.""" - return await self._send_request("GET", "/analytics/health") - - # Custom Reports APIs - async def create_custom_report(self, report_config: Dict[str, Any]): - """Create a custom analytics report.""" - return await self._send_request("POST", "/analytics/reports", json=report_config) - - async def list_custom_reports(self, page: int = 1, limit: int = 20): - """List custom analytics reports.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/analytics/reports", params=params) - - async def get_custom_report(self, report_id: str): - """Get a specific custom report.""" - url = f"/analytics/reports/{report_id}" - return await self._send_request("GET", url) - - async def update_custom_report(self, report_id: str, report_config: Dict[str, Any]): - """Update a custom analytics report.""" - url = f"/analytics/reports/{report_id}" - return await self._send_request("PUT", url, json=report_config) - - async def delete_custom_report(self, report_id: str): - """Delete a custom analytics report.""" - url = f"/analytics/reports/{report_id}" - return await self._send_request("DELETE", url) - - async def generate_report(self, report_id: str, format: str = "pdf"): - """Generate and download a custom report.""" - params = {"format": format} - url = f"/analytics/reports/{report_id}/generate" - return await self._send_request("GET", url, params=params) - - # Export APIs - async def export_analytics_data(self, data_type: str, start_date: str, end_date: str, format: str = "csv"): - """Export analytics data in specified format.""" - params = { - "data_type": data_type, - "start_date": start_date, - "end_date": end_date, - "format": format, - } - return await self._send_request("GET", "/analytics/export", params=params) - - -class AsyncIntegrationClient(AsyncDifyClient): - """Async Integration and Plugin APIs for Dify platform extensibility.""" - - # Webhook Management APIs - async def list_webhooks(self, page: int = 1, limit: int = 20, status: str | None = None): - """List webhooks with pagination and filtering.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - return await self._send_request("GET", "/integrations/webhooks", params=params) - - async def create_webhook(self, webhook_data: Dict[str, Any]): - """Create a new webhook.""" - return await self._send_request("POST", "/integrations/webhooks", json=webhook_data) - - async def get_webhook(self, webhook_id: str): - """Get detailed information about a webhook.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("GET", url) - - async def update_webhook(self, webhook_id: str, webhook_data: Dict[str, Any]): - """Update webhook configuration.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("PUT", url, json=webhook_data) - - async def delete_webhook(self, webhook_id: str): - """Delete a webhook.""" - url = f"/integrations/webhooks/{webhook_id}" - return await self._send_request("DELETE", url) - - async def test_webhook(self, webhook_id: str): - """Test webhook delivery.""" - url = f"/integrations/webhooks/{webhook_id}/test" - return await self._send_request("POST", url) - - async def get_webhook_logs(self, webhook_id: str, page: int = 1, limit: int = 20): - """Get webhook delivery logs.""" - params = {"page": page, "limit": limit} - url = f"/integrations/webhooks/{webhook_id}/logs" - return await self._send_request("GET", url, params=params) - - # Plugin Management APIs - async def list_plugins(self, page: int = 1, limit: int = 20, category: str | None = None): - """List available plugins.""" - params = {"page": page, "limit": limit} - if category: - params["category"] = category - return await self._send_request("GET", "/integrations/plugins", params=params) - - async def install_plugin(self, plugin_id: str, config: Dict[str, Any] | None = None): - """Install a plugin.""" - data = {"plugin_id": plugin_id} - if config: - data["config"] = config - return await self._send_request("POST", "/integrations/plugins/install", json=data) - - async def get_installed_plugin(self, installation_id: str): - """Get information about an installed plugin.""" - url = f"/integrations/plugins/{installation_id}" - return await self._send_request("GET", url) - - async def update_plugin_config(self, installation_id: str, config: Dict[str, Any]): - """Update plugin configuration.""" - url = f"/integrations/plugins/{installation_id}/config" - return await self._send_request("PUT", url, json=config) - - async def uninstall_plugin(self, installation_id: str): - """Uninstall a plugin.""" - url = f"/integrations/plugins/{installation_id}" - return await self._send_request("DELETE", url) - - async def enable_plugin(self, installation_id: str): - """Enable a plugin.""" - url = f"/integrations/plugins/{installation_id}/enable" - return await self._send_request("POST", url) - - async def disable_plugin(self, installation_id: str): - """Disable a plugin.""" - url = f"/integrations/plugins/{installation_id}/disable" - return await self._send_request("POST", url) - - # Import/Export APIs - async def export_app_data(self, app_id: str, format: str = "json", include_data: bool = True): - """Export application data.""" - params = {"format": format, "include_data": include_data} - url = f"/integrations/export/apps/{app_id}" - return await self._send_request("GET", url, params=params) - - async def import_app_data(self, import_data: Dict[str, Any]): - """Import application data.""" - return await self._send_request("POST", "/integrations/import/apps", json=import_data) - - async def get_import_status(self, import_id: str): - """Get import operation status.""" - url = f"/integrations/import/{import_id}/status" - return await self._send_request("GET", url) - - async def export_workspace_data(self, format: str = "json", include_data: bool = True): - """Export workspace data.""" - params = {"format": format, "include_data": include_data} - return await self._send_request("GET", "/integrations/export/workspace", params=params) - - async def import_workspace_data(self, import_data: Dict[str, Any]): - """Import workspace data.""" - return await self._send_request("POST", "/integrations/import/workspace", json=import_data) - - # Backup and Restore APIs - async def create_backup(self, backup_config: Dict[str, Any] | None = None): - """Create a system backup.""" - data = backup_config or {} - return await self._send_request("POST", "/integrations/backup/create", json=data) - - async def list_backups(self, page: int = 1, limit: int = 20): - """List available backups.""" - params = {"page": page, "limit": limit} - return await self._send_request("GET", "/integrations/backup", params=params) - - async def get_backup(self, backup_id: str): - """Get backup information.""" - url = f"/integrations/backup/{backup_id}" - return await self._send_request("GET", url) - - async def restore_backup(self, backup_id: str, restore_config: Dict[str, Any] | None = None): - """Restore from backup.""" - data = restore_config or {} - url = f"/integrations/backup/{backup_id}/restore" - return await self._send_request("POST", url, json=data) - - async def delete_backup(self, backup_id: str): - """Delete a backup.""" - url = f"/integrations/backup/{backup_id}" - return await self._send_request("DELETE", url) - - -class AsyncAdvancedModelClient(AsyncDifyClient): - """Async Advanced Model Management APIs for fine-tuning and custom deployments.""" - - # Fine-tuning Job Management APIs - async def list_fine_tuning_jobs( - self, - page: int = 1, - limit: int = 20, - status: str | None = None, - model_provider: str | None = None, - ): - """List fine-tuning jobs with filtering.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - if model_provider: - params["model_provider"] = model_provider - return await self._send_request("GET", "/models/fine-tuning/jobs", params=params) - - async def create_fine_tuning_job(self, job_config: Dict[str, Any]): - """Create a new fine-tuning job.""" - return await self._send_request("POST", "/models/fine-tuning/jobs", json=job_config) - - async def get_fine_tuning_job(self, job_id: str): - """Get fine-tuning job details.""" - url = f"/models/fine-tuning/jobs/{job_id}" - return await self._send_request("GET", url) - - async def update_fine_tuning_job(self, job_id: str, job_config: Dict[str, Any]): - """Update fine-tuning job configuration.""" - url = f"/models/fine-tuning/jobs/{job_id}" - return await self._send_request("PUT", url, json=job_config) - - async def cancel_fine_tuning_job(self, job_id: str): - """Cancel a fine-tuning job.""" - url = f"/models/fine-tuning/jobs/{job_id}/cancel" - return await self._send_request("POST", url) - - async def resume_fine_tuning_job(self, job_id: str): - """Resume a paused fine-tuning job.""" - url = f"/models/fine-tuning/jobs/{job_id}/resume" - return await self._send_request("POST", url) - - async def get_fine_tuning_job_metrics(self, job_id: str): - """Get fine-tuning job training metrics.""" - url = f"/models/fine-tuning/jobs/{job_id}/metrics" - return await self._send_request("GET", url) - - async def get_fine_tuning_job_logs(self, job_id: str, page: int = 1, limit: int = 50): - """Get fine-tuning job logs.""" - params = {"page": page, "limit": limit} - url = f"/models/fine-tuning/jobs/{job_id}/logs" - return await self._send_request("GET", url, params=params) - - # Custom Model Deployment APIs - async def list_custom_deployments(self, page: int = 1, limit: int = 20, status: str | None = None): - """List custom model deployments.""" - params = {"page": page, "limit": limit} - if status: - params["status"] = status - return await self._send_request("GET", "/models/custom/deployments", params=params) - - async def create_custom_deployment(self, deployment_config: Dict[str, Any]): - """Create a custom model deployment.""" - return await self._send_request("POST", "/models/custom/deployments", json=deployment_config) - - async def get_custom_deployment(self, deployment_id: str): - """Get custom deployment details.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("GET", url) - - async def update_custom_deployment(self, deployment_id: str, deployment_config: Dict[str, Any]): - """Update custom deployment configuration.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("PUT", url, json=deployment_config) - - async def delete_custom_deployment(self, deployment_id: str): - """Delete a custom deployment.""" - url = f"/models/custom/deployments/{deployment_id}" - return await self._send_request("DELETE", url) - - async def scale_custom_deployment(self, deployment_id: str, scale_config: Dict[str, Any]): - """Scale custom deployment resources.""" - url = f"/models/custom/deployments/{deployment_id}/scale" - return await self._send_request("POST", url, json=scale_config) - - async def restart_custom_deployment(self, deployment_id: str): - """Restart a custom deployment.""" - url = f"/models/custom/deployments/{deployment_id}/restart" - return await self._send_request("POST", url) - - # Model Performance Monitoring APIs - async def get_model_performance_history( - self, - model_provider: str, - model_name: str, - start_date: str, - end_date: str, - metrics: List[str] | None = None, - ): - """Get model performance history.""" - params = {"start_date": start_date, "end_date": end_date} - if metrics: - params["metrics"] = ",".join(metrics) - url = f"/models/{model_provider}/{model_name}/performance/history" - return await self._send_request("GET", url, params=params) - - async def get_model_health_metrics(self, model_provider: str, model_name: str): - """Get real-time model health metrics.""" - url = f"/models/{model_provider}/{model_name}/health" - return await self._send_request("GET", url) - - async def get_model_usage_stats( - self, - model_provider: str, - model_name: str, - start_date: str, - end_date: str, - granularity: str = "day", - ): - """Get model usage statistics.""" - params = { - "start_date": start_date, - "end_date": end_date, - "granularity": granularity, - } - url = f"/models/{model_provider}/{model_name}/usage" - return await self._send_request("GET", url, params=params) - - async def get_model_cost_analysis(self, model_provider: str, model_name: str, start_date: str, end_date: str): - """Get model cost analysis.""" - params = {"start_date": start_date, "end_date": end_date} - url = f"/models/{model_provider}/{model_name}/costs" - return await self._send_request("GET", url, params=params) - - # Model Versioning APIs - async def list_model_versions(self, model_provider: str, model_name: str, page: int = 1, limit: int = 20): - """List model versions.""" - params = {"page": page, "limit": limit} - url = f"/models/{model_provider}/{model_name}/versions" - return await self._send_request("GET", url, params=params) - - async def create_model_version(self, model_provider: str, model_name: str, version_config: Dict[str, Any]): - """Create a new model version.""" - url = f"/models/{model_provider}/{model_name}/versions" - return await self._send_request("POST", url, json=version_config) - - async def get_model_version(self, model_provider: str, model_name: str, version_id: str): - """Get model version details.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}" - return await self._send_request("GET", url) - - async def promote_model_version(self, model_provider: str, model_name: str, version_id: str): - """Promote model version to production.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}/promote" - return await self._send_request("POST", url) - - async def rollback_model_version(self, model_provider: str, model_name: str, version_id: str): - """Rollback to a specific model version.""" - url = f"/models/{model_provider}/{model_name}/versions/{version_id}/rollback" - return await self._send_request("POST", url) - - # Model Registry APIs - async def list_registry_models(self, page: int = 1, limit: int = 20, filter: str | None = None): - """List models in registry.""" - params = {"page": page, "limit": limit} - if filter: - params["filter"] = filter - return await self._send_request("GET", "/models/registry", params=params) - - async def register_model(self, model_config: Dict[str, Any]): - """Register a new model in the registry.""" - return await self._send_request("POST", "/models/registry", json=model_config) - - async def get_registry_model(self, model_id: str): - """Get registered model details.""" - url = f"/models/registry/{model_id}" - return await self._send_request("GET", url) - - async def update_registry_model(self, model_id: str, model_config: Dict[str, Any]): - """Update registered model information.""" - url = f"/models/registry/{model_id}" - return await self._send_request("PUT", url, json=model_config) - - async def unregister_model(self, model_id: str): - """Unregister a model from the registry.""" - url = f"/models/registry/{model_id}" - return await self._send_request("DELETE", url) - - -class AsyncAdvancedAppClient(AsyncDifyClient): - """Async Advanced App Configuration APIs for comprehensive app management.""" - - # App Creation and Management APIs - async def create_app(self, app_config: Dict[str, Any]): - """Create a new application.""" - return await self._send_request("POST", "/apps", json=app_config) - - async def list_apps( - self, - page: int = 1, - limit: int = 20, - app_type: str | None = None, - status: str | None = None, - ): - """List applications with filtering.""" - params = {"page": page, "limit": limit} - if app_type: - params["app_type"] = app_type - if status: - params["status"] = status - return await self._send_request("GET", "/apps", params=params) - - async def get_app(self, app_id: str): - """Get detailed application information.""" - url = f"/apps/{app_id}" - return await self._send_request("GET", url) - - async def update_app(self, app_id: str, app_config: Dict[str, Any]): - """Update application configuration.""" - url = f"/apps/{app_id}" - return await self._send_request("PUT", url, json=app_config) - - async def delete_app(self, app_id: str): - """Delete an application.""" - url = f"/apps/{app_id}" - return await self._send_request("DELETE", url) - - async def duplicate_app(self, app_id: str, duplicate_config: Dict[str, Any]): - """Duplicate an application.""" - url = f"/apps/{app_id}/duplicate" - return await self._send_request("POST", url, json=duplicate_config) - - async def archive_app(self, app_id: str): - """Archive an application.""" - url = f"/apps/{app_id}/archive" - return await self._send_request("POST", url) - - async def restore_app(self, app_id: str): - """Restore an archived application.""" - url = f"/apps/{app_id}/restore" - return await self._send_request("POST", url) - - # App Publishing and Versioning APIs - async def publish_app(self, app_id: str, publish_config: Dict[str, Any] | None = None): - """Publish an application.""" - data = publish_config or {} - url = f"/apps/{app_id}/publish" - return await self._send_request("POST", url, json=data) - - async def unpublish_app(self, app_id: str): - """Unpublish an application.""" - url = f"/apps/{app_id}/unpublish" - return await self._send_request("POST", url) - - async def list_app_versions(self, app_id: str, page: int = 1, limit: int = 20): - """List application versions.""" - params = {"page": page, "limit": limit} - url = f"/apps/{app_id}/versions" - return await self._send_request("GET", url, params=params) - - async def create_app_version(self, app_id: str, version_config: Dict[str, Any]): - """Create a new application version.""" - url = f"/apps/{app_id}/versions" - return await self._send_request("POST", url, json=version_config) - - async def get_app_version(self, app_id: str, version_id: str): - """Get application version details.""" - url = f"/apps/{app_id}/versions/{version_id}" - return await self._send_request("GET", url) - - async def rollback_app_version(self, app_id: str, version_id: str): - """Rollback application to a specific version.""" - url = f"/apps/{app_id}/versions/{version_id}/rollback" - return await self._send_request("POST", url) - - # App Template APIs - async def list_app_templates(self, page: int = 1, limit: int = 20, category: str | None = None): - """List available app templates.""" - params = {"page": page, "limit": limit} - if category: - params["category"] = category - return await self._send_request("GET", "/apps/templates", params=params) - - async def get_app_template(self, template_id: str): - """Get app template details.""" - url = f"/apps/templates/{template_id}" - return await self._send_request("GET", url) - - async def create_app_from_template(self, template_id: str, app_config: Dict[str, Any]): - """Create an app from a template.""" - url = f"/apps/templates/{template_id}/create" - return await self._send_request("POST", url, json=app_config) - - async def create_custom_template(self, app_id: str, template_config: Dict[str, Any]): - """Create a custom template from an existing app.""" - url = f"/apps/{app_id}/create-template" - return await self._send_request("POST", url, json=template_config) - - # App Analytics and Metrics APIs - async def get_app_analytics( - self, - app_id: str, - start_date: str, - end_date: str, - metrics: List[str] | None = None, - ): - """Get application analytics.""" - params = {"start_date": start_date, "end_date": end_date} - if metrics: - params["metrics"] = ",".join(metrics) - url = f"/apps/{app_id}/analytics" - return await self._send_request("GET", url, params=params) - - async def get_app_user_feedback(self, app_id: str, page: int = 1, limit: int = 20, rating: int | None = None): - """Get user feedback for an application.""" - params = {"page": page, "limit": limit} - if rating: - params["rating"] = rating - url = f"/apps/{app_id}/feedback" - return await self._send_request("GET", url, params=params) - - async def get_app_error_logs( - self, - app_id: str, - start_date: str, - end_date: str, - error_type: str | None = None, - page: int = 1, - limit: int = 20, - ): - """Get application error logs.""" - params = { - "start_date": start_date, - "end_date": end_date, - "page": page, - "limit": limit, - } - if error_type: - params["error_type"] = error_type - url = f"/apps/{app_id}/errors" - return await self._send_request("GET", url, params=params) - - # Advanced Configuration APIs - async def get_app_advanced_config(self, app_id: str): - """Get advanced application configuration.""" - url = f"/apps/{app_id}/advanced-config" - return await self._send_request("GET", url) - - async def update_app_advanced_config(self, app_id: str, config: Dict[str, Any]): - """Update advanced application configuration.""" - url = f"/apps/{app_id}/advanced-config" - return await self._send_request("PUT", url, json=config) - - async def get_app_environment_variables(self, app_id: str): - """Get application environment variables.""" - url = f"/apps/{app_id}/environment" - return await self._send_request("GET", url) - - async def update_app_environment_variables(self, app_id: str, variables: Dict[str, str]): - """Update application environment variables.""" - url = f"/apps/{app_id}/environment" - return await self._send_request("PUT", url, json=variables) - - async def get_app_resource_limits(self, app_id: str): - """Get application resource limits.""" - url = f"/apps/{app_id}/resource-limits" - return await self._send_request("GET", url) - - async def update_app_resource_limits(self, app_id: str, limits: Dict[str, Any]): - """Update application resource limits.""" - url = f"/apps/{app_id}/resource-limits" - return await self._send_request("PUT", url, json=limits) - - # App Integration APIs - async def get_app_integrations(self, app_id: str): - """Get application integrations.""" - url = f"/apps/{app_id}/integrations" - return await self._send_request("GET", url) - - async def add_app_integration(self, app_id: str, integration_config: Dict[str, Any]): - """Add integration to application.""" - url = f"/apps/{app_id}/integrations" - return await self._send_request("POST", url, json=integration_config) - - async def update_app_integration(self, app_id: str, integration_id: str, config: Dict[str, Any]): - """Update application integration.""" - url = f"/apps/{app_id}/integrations/{integration_id}" - return await self._send_request("PUT", url, json=config) - - async def remove_app_integration(self, app_id: str, integration_id: str): - """Remove integration from application.""" - url = f"/apps/{app_id}/integrations/{integration_id}" - return await self._send_request("DELETE", url) - - async def test_app_integration(self, app_id: str, integration_id: str): - """Test application integration.""" - url = f"/apps/{app_id}/integrations/{integration_id}/test" - return await self._send_request("POST", url) diff --git a/sdks/python-client/dify_client/base_client.py b/sdks/python-client/dify_client/base_client.py deleted file mode 100644 index 0ad6e07b23..0000000000 --- a/sdks/python-client/dify_client/base_client.py +++ /dev/null @@ -1,228 +0,0 @@ -"""Base client with common functionality for both sync and async clients.""" - -import json -import time -import logging -from typing import Dict, Callable, Optional - -try: - # Python 3.10+ - from typing import ParamSpec -except ImportError: - # Python < 3.10 - from typing_extensions import ParamSpec - -from urllib.parse import urljoin - -import httpx - -P = ParamSpec("P") - -from .exceptions import ( - DifyClientError, - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, -) - - -class BaseClientMixin: - """Mixin class providing common functionality for Dify clients.""" - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - max_retries: int = 3, - retry_delay: float = 1.0, - enable_logging: bool = False, - ): - """Initialize the base client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds - max_retries: Maximum number of retry attempts - retry_delay: Delay between retries in seconds - enable_logging: Enable detailed logging - """ - if not api_key: - raise ValidationError("API key is required") - - self.api_key = api_key - self.base_url = base_url.rstrip("/") - self.timeout = timeout - self.max_retries = max_retries - self.retry_delay = retry_delay - self.enable_logging = enable_logging - - # Setup logging - self.logger = logging.getLogger(f"dify_client.{self.__class__.__name__.lower()}") - if enable_logging and not self.logger.handlers: - # Create console handler with formatter - handler = logging.StreamHandler() - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - handler.setFormatter(formatter) - self.logger.addHandler(handler) - self.logger.setLevel(logging.INFO) - self.enable_logging = True - else: - self.enable_logging = enable_logging - - def _get_headers(self, content_type: str = "application/json") -> Dict[str, str]: - """Get common request headers.""" - return { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": content_type, - "User-Agent": "dify-client-python/0.1.12", - } - - def _build_url(self, endpoint: str) -> str: - """Build full URL from endpoint.""" - return urljoin(self.base_url + "/", endpoint.lstrip("/")) - - def _handle_response(self, response: httpx.Response) -> httpx.Response: - """Handle HTTP response and raise appropriate exceptions.""" - try: - if response.status_code == 401: - raise AuthenticationError( - "Authentication failed. Check your API key.", - status_code=response.status_code, - response=response.json() if response.content else None, - ) - elif response.status_code == 429: - retry_after = response.headers.get("Retry-After") - raise RateLimitError( - "Rate limit exceeded. Please try again later.", - retry_after=int(retry_after) if retry_after else None, - ) - elif response.status_code >= 400: - try: - error_data = response.json() - message = error_data.get("message", f"HTTP {response.status_code}") - except: - message = f"HTTP {response.status_code}: {response.text}" - - raise APIError( - message, - status_code=response.status_code, - response=response.json() if response.content else None, - ) - - return response - - except json.JSONDecodeError: - raise APIError( - f"Invalid JSON response: {response.text}", - status_code=response.status_code, - ) - - def _retry_request( - self, - request_func: Callable[P, httpx.Response], - request_context: str | None = None, - *args: P.args, - **kwargs: P.kwargs, - ) -> httpx.Response: - """Retry a request with exponential backoff. - - Args: - request_func: Function that performs the HTTP request - request_context: Context description for logging (e.g., "GET /v1/messages") - *args: Positional arguments to pass to request_func - **kwargs: Keyword arguments to pass to request_func - - Returns: - httpx.Response: Successful response - - Raises: - NetworkError: On network failures after retries - TimeoutError: On timeout failures after retries - APIError: On API errors (4xx/5xx responses) - DifyClientError: On unexpected failures - """ - last_exception = None - - for attempt in range(self.max_retries + 1): - try: - response = request_func(*args, **kwargs) - return response # Let caller handle response processing - - except (httpx.NetworkError, httpx.TimeoutException) as e: - last_exception = e - context_msg = f" {request_context}" if request_context else "" - - if attempt < self.max_retries: - delay = self.retry_delay * (2**attempt) # Exponential backoff - self.logger.warning( - f"Request failed{context_msg} (attempt {attempt + 1}/{self.max_retries + 1}): {e}. " - f"Retrying in {delay:.2f} seconds..." - ) - time.sleep(delay) - else: - self.logger.error(f"Request failed{context_msg} after {self.max_retries + 1} attempts: {e}") - # Convert to custom exceptions - if isinstance(e, httpx.TimeoutException): - from .exceptions import TimeoutError - - raise TimeoutError(f"Request timed out after {self.max_retries} retries{context_msg}") from e - else: - from .exceptions import NetworkError - - raise NetworkError( - f"Network error after {self.max_retries} retries{context_msg}: {str(e)}" - ) from e - - if last_exception: - raise last_exception - raise DifyClientError("Request failed after retries") - - def _validate_params(self, **params) -> None: - """Validate request parameters.""" - for key, value in params.items(): - if value is None: - continue - - # String validations - if isinstance(value, str): - if not value.strip(): - raise ValidationError(f"Parameter '{key}' cannot be empty or whitespace only") - if len(value) > 10000: - raise ValidationError(f"Parameter '{key}' exceeds maximum length of 10000 characters") - - # List validations - elif isinstance(value, list): - if len(value) > 1000: - raise ValidationError(f"Parameter '{key}' exceeds maximum size of 1000 items") - - # Dictionary validations - elif isinstance(value, dict): - if len(value) > 100: - raise ValidationError(f"Parameter '{key}' exceeds maximum size of 100 items") - - # Type-specific validations - if key == "user" and not isinstance(value, str): - raise ValidationError(f"Parameter '{key}' must be a string") - elif key in ["page", "limit", "page_size"] and not isinstance(value, int): - raise ValidationError(f"Parameter '{key}' must be an integer") - elif key == "files" and not isinstance(value, (list, dict)): - raise ValidationError(f"Parameter '{key}' must be a list or dict") - elif key == "rating" and value not in ["like", "dislike"]: - raise ValidationError(f"Parameter '{key}' must be 'like' or 'dislike'") - - def _log_request(self, method: str, url: str, **kwargs) -> None: - """Log request details.""" - self.logger.info(f"Making {method} request to {url}") - if kwargs.get("json"): - self.logger.debug(f"Request body: {kwargs['json']}") - if kwargs.get("params"): - self.logger.debug(f"Query params: {kwargs['params']}") - - def _log_response(self, response: httpx.Response) -> None: - """Log response details.""" - self.logger.info(f"Received response: {response.status_code} ({len(response.content)} bytes)") diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py deleted file mode 100644 index cebdf6845c..0000000000 --- a/sdks/python-client/dify_client/client.py +++ /dev/null @@ -1,1267 +0,0 @@ -import json -import logging -import os -from typing import Literal, Dict, List, Any, IO, Optional, Union - -import httpx -from .base_client import BaseClientMixin -from .exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - FileUploadError, -) - - -class DifyClient(BaseClientMixin): - """Synchronous Dify API client. - - This client uses httpx.Client for efficient connection pooling and resource management. - It's recommended to use this client as a context manager: - - Example: - with DifyClient(api_key="your-key") as client: - response = client.get_app_info() - """ - - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - timeout: float = 60.0, - max_retries: int = 3, - retry_delay: float = 1.0, - enable_logging: bool = False, - ): - """Initialize the Dify client. - - Args: - api_key: Your Dify API key - base_url: Base URL for the Dify API - timeout: Request timeout in seconds (default: 60.0) - max_retries: Maximum number of retry attempts (default: 3) - retry_delay: Delay between retries in seconds (default: 1.0) - enable_logging: Whether to enable request logging (default: True) - """ - # Initialize base client functionality - BaseClientMixin.__init__(self, api_key, base_url, timeout, max_retries, retry_delay, enable_logging) - - self._client = httpx.Client( - base_url=base_url, - timeout=httpx.Timeout(timeout, connect=5.0), - ) - - def __enter__(self): - """Support context manager protocol.""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Clean up resources when exiting context.""" - self.close() - - def close(self): - """Close the HTTP client and release resources.""" - if hasattr(self, "_client"): - self._client.close() - - def _send_request( - self, - method: str, - endpoint: str, - json: Dict[str, Any] | None = None, - params: Dict[str, Any] | None = None, - stream: bool = False, - **kwargs, - ): - """Send an HTTP request to the Dify API with retry logic. - - Args: - method: HTTP method (GET, POST, PUT, PATCH, DELETE) - endpoint: API endpoint path - json: JSON request body - params: Query parameters - stream: Whether to stream the response - **kwargs: Additional arguments to pass to httpx.request - - Returns: - httpx.Response object - """ - # Validate parameters - if json: - self._validate_params(**json) - if params: - self._validate_params(**params) - - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - - def make_request(): - """Inner function to perform the actual HTTP request.""" - # Log request if logging is enabled - if self.enable_logging: - self.logger.info(f"Sending {method} request to {endpoint}") - # Debug logging for detailed information - if self.logger.isEnabledFor(logging.DEBUG): - if json: - self.logger.debug(f"Request body: {json}") - if params: - self.logger.debug(f"Request params: {params}") - - # httpx.Client automatically prepends base_url - response = self._client.request( - method, - endpoint, - json=json, - params=params, - headers=headers, - **kwargs, - ) - - # Log response if logging is enabled - if self.enable_logging: - self.logger.info(f"Received response: {response.status_code}") - - return response - - # Use the retry mechanism from base client - request_context = f"{method} {endpoint}" - response = self._retry_request(make_request, request_context) - - # Handle error responses (API errors don't retry) - self._handle_error_response(response) - - return response - - def _handle_error_response(self, response, is_upload_request: bool = False) -> None: - """Handle HTTP error responses and raise appropriate exceptions.""" - - if response.status_code < 400: - return # Success response - - try: - error_data = response.json() - message = error_data.get("message", f"HTTP {response.status_code}") - except (ValueError, KeyError): - message = f"HTTP {response.status_code}" - error_data = None - - # Log error response if logging is enabled - if self.enable_logging: - self.logger.error(f"API error: {response.status_code} - {message}") - - if response.status_code == 401: - raise AuthenticationError(message, response.status_code, error_data) - elif response.status_code == 429: - retry_after = response.headers.get("Retry-After") - raise RateLimitError(message, retry_after) - elif response.status_code == 422: - raise ValidationError(message, response.status_code, error_data) - elif response.status_code == 400: - # Check if this is a file upload error based on the URL or context - current_url = getattr(response, "url", "") or "" - if is_upload_request or "upload" in str(current_url).lower() or "files" in str(current_url).lower(): - raise FileUploadError(message, response.status_code, error_data) - else: - raise APIError(message, response.status_code, error_data) - elif response.status_code >= 500: - # Server errors should raise APIError - raise APIError(message, response.status_code, error_data) - elif response.status_code >= 400: - raise APIError(message, response.status_code, error_data) - - def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict): - """Send an HTTP request with file uploads. - - Args: - method: HTTP method (POST, PUT, etc.) - endpoint: API endpoint path - data: Form data - files: Files to upload - - Returns: - httpx.Response object - """ - headers = {"Authorization": f"Bearer {self.api_key}"} - - # Log file upload request if logging is enabled - if self.enable_logging: - self.logger.info(f"Sending {method} file upload request to {endpoint}") - self.logger.debug(f"Form data: {data}") - self.logger.debug(f"Files: {files}") - - response = self._client.request( - method, - endpoint, - data=data, - headers=headers, - files=files, - ) - - # Log response if logging is enabled - if self.enable_logging: - self.logger.info(f"Received file upload response: {response.status_code}") - - # Handle error responses - self._handle_error_response(response, is_upload_request=True) - - return response - - def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str): - self._validate_params(message_id=message_id, rating=rating, user=user) - data = {"rating": rating, "user": user} - return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - - def get_application_parameters(self, user: str): - params = {"user": user} - return self._send_request("GET", "/parameters", params=params) - - def file_upload(self, user: str, files: dict): - data = {"user": user} - return self._send_request_with_files("POST", "/files/upload", data=data, files=files) - - def text_to_audio(self, text: str, user: str, streaming: bool = False): - data = {"text": text, "user": user, "streaming": streaming} - return self._send_request("POST", "/text-to-audio", json=data) - - def get_meta(self, user: str): - params = {"user": user} - return self._send_request("GET", "/meta", params=params) - - def get_app_info(self): - """Get basic application information including name, description, tags, and mode.""" - return self._send_request("GET", "/info") - - def get_app_site_info(self): - """Get application site information.""" - return self._send_request("GET", "/site") - - def get_file_preview(self, file_id: str): - """Get file preview by file ID.""" - return self._send_request("GET", f"/files/{file_id}/preview") - - # App Configuration APIs - def get_app_site_config(self, app_id: str): - """Get app site configuration. - - Args: - app_id: ID of the app - - Returns: - App site configuration - """ - url = f"/apps/{app_id}/site/config" - return self._send_request("GET", url) - - def update_app_site_config(self, app_id: str, config_data: Dict[str, Any]): - """Update app site configuration. - - Args: - app_id: ID of the app - config_data: Configuration data to update - - Returns: - Updated app site configuration - """ - url = f"/apps/{app_id}/site/config" - return self._send_request("PUT", url, json=config_data) - - def get_app_api_tokens(self, app_id: str): - """Get API tokens for an app. - - Args: - app_id: ID of the app - - Returns: - List of API tokens - """ - url = f"/apps/{app_id}/api-tokens" - return self._send_request("GET", url) - - def create_app_api_token(self, app_id: str, name: str, description: str | None = None): - """Create a new API token for an app. - - Args: - app_id: ID of the app - name: Name for the API token - description: Description for the API token (optional) - - Returns: - Created API token information - """ - data = {"name": name, "description": description} - url = f"/apps/{app_id}/api-tokens" - return self._send_request("POST", url, json=data) - - def delete_app_api_token(self, app_id: str, token_id: str): - """Delete an API token. - - Args: - app_id: ID of the app - token_id: ID of the token to delete - - Returns: - Deletion result - """ - url = f"/apps/{app_id}/api-tokens/{token_id}" - return self._send_request("DELETE", url) - - -class CompletionClient(DifyClient): - def create_completion_message( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"], - user: str, - files: Dict[str, Any] | None = None, - ): - # Validate parameters - if not isinstance(inputs, dict): - raise ValidationError("inputs must be a dictionary") - if response_mode not in ["blocking", "streaming"]: - raise ValidationError("response_mode must be 'blocking' or 'streaming'") - - self._validate_params(inputs=inputs, response_mode=response_mode, user=user) - - data = { - "inputs": inputs, - "response_mode": response_mode, - "user": user, - "files": files, - } - return self._send_request( - "POST", - "/completion-messages", - data, - stream=(response_mode == "streaming"), - ) - - -class ChatClient(DifyClient): - def create_chat_message( - self, - inputs: dict, - query: str, - user: str, - response_mode: Literal["blocking", "streaming"] = "blocking", - conversation_id: str | None = None, - files: Dict[str, Any] | None = None, - ): - # Validate parameters - if not isinstance(inputs, dict): - raise ValidationError("inputs must be a dictionary") - if not isinstance(query, str) or not query.strip(): - raise ValidationError("query must be a non-empty string") - if response_mode not in ["blocking", "streaming"]: - raise ValidationError("response_mode must be 'blocking' or 'streaming'") - - self._validate_params(inputs=inputs, query=query, user=user, response_mode=response_mode) - - data = { - "inputs": inputs, - "query": query, - "user": user, - "response_mode": response_mode, - "files": files, - } - if conversation_id: - data["conversation_id"] = conversation_id - - return self._send_request( - "POST", - "/chat-messages", - data, - stream=(response_mode == "streaming"), - ) - - def get_suggested(self, message_id: str, user: str): - params = {"user": user} - return self._send_request("GET", f"/messages/{message_id}/suggested", params=params) - - def stop_message(self, task_id: str, user: str): - data = {"user": user} - return self._send_request("POST", f"/chat-messages/{task_id}/stop", data) - - def get_conversations( - self, - user: str, - last_id: str | None = None, - limit: int | None = None, - pinned: bool | None = None, - ): - params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} - return self._send_request("GET", "/conversations", params=params) - - def get_conversation_messages( - self, - user: str, - conversation_id: str | None = None, - first_id: str | None = None, - limit: int | None = None, - ): - params = {"user": user} - - if conversation_id: - params["conversation_id"] = conversation_id - if first_id: - params["first_id"] = first_id - if limit: - params["limit"] = limit - - return self._send_request("GET", "/messages", params=params) - - def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str): - data = {"name": name, "auto_generate": auto_generate, "user": user} - return self._send_request("POST", f"/conversations/{conversation_id}/name", data) - - def delete_conversation(self, conversation_id: str, user: str): - data = {"user": user} - return self._send_request("DELETE", f"/conversations/{conversation_id}", data) - - def audio_to_text(self, audio_file: Union[IO[bytes], tuple], user: str): - data = {"user": user} - files = {"file": audio_file} - return self._send_request_with_files("POST", "/audio-to-text", data, files) - - # Annotation APIs - def annotation_reply_action( - self, - action: Literal["enable", "disable"], - score_threshold: float, - embedding_provider_name: str, - embedding_model_name: str, - ): - """Enable or disable annotation reply feature.""" - data = { - "score_threshold": score_threshold, - "embedding_provider_name": embedding_provider_name, - "embedding_model_name": embedding_model_name, - } - return self._send_request("POST", f"/apps/annotation-reply/{action}", json=data) - - def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str): - """Get the status of an annotation reply action job.""" - return self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}") - - def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations for the application.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return self._send_request("GET", "/apps/annotations", params=params) - - def create_annotation(self, question: str, answer: str): - """Create a new annotation.""" - data = {"question": question, "answer": answer} - return self._send_request("POST", "/apps/annotations", json=data) - - def update_annotation(self, annotation_id: str, question: str, answer: str): - """Update an existing annotation.""" - data = {"question": question, "answer": answer} - return self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data) - - def delete_annotation(self, annotation_id: str): - """Delete an annotation.""" - return self._send_request("DELETE", f"/apps/annotations/{annotation_id}") - - # Conversation Variables APIs - def get_conversation_variables(self, conversation_id: str, user: str): - """Get all variables for a specific conversation. - - Args: - conversation_id: The conversation ID to query variables for - user: User identifier - - Returns: - Response from the API containing: - - variables: List of conversation variables with their values - - conversation_id: The conversation ID - """ - params = {"user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str): - """Update a specific conversation variable. - - Args: - conversation_id: The conversation ID - variable_id: The variable ID to update - value: New value for the variable - user: User identifier - - Returns: - Response from the API with updated variable information - """ - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) - - def delete_annotation_with_response(self, annotation_id: str): - """Delete an annotation with full response handling.""" - url = f"/apps/annotations/{annotation_id}" - return self._send_request("DELETE", url) - - def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable_with_response(self, conversation_id: str, variable_id: str, user: str, value: Any): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) - - # Enhanced Annotation APIs - def get_annotation_reply_job_status(self, action: str, job_id: str): - """Get status of an annotation reply action job.""" - url = f"/apps/annotation-reply/{action}/status/{job_id}" - return self._send_request("GET", url) - - def list_annotations_with_pagination(self, page: int = 1, limit: int = 20, keyword: str | None = None): - """List annotations with pagination.""" - params = {"page": page, "limit": limit, "keyword": keyword} - return self._send_request("GET", "/apps/annotations", params=params) - - def create_annotation_with_response(self, question: str, answer: str): - """Create an annotation with full response handling.""" - data = {"question": question, "answer": answer} - return self._send_request("POST", "/apps/annotations", json=data) - - def update_annotation_with_response(self, annotation_id: str, question: str, answer: str): - """Update an annotation with full response handling.""" - data = {"question": question, "answer": answer} - url = f"/apps/annotations/{annotation_id}" - return self._send_request("PUT", url, json=data) - - -class WorkflowClient(DifyClient): - def run( - self, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return self._send_request("POST", "/workflows/run", data) - - def stop(self, task_id, user): - data = {"user": user} - return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) - - def get_result(self, workflow_run_id): - return self._send_request("GET", f"/workflows/run/{workflow_run_id}") - - def get_workflow_logs( - self, - keyword: str = None, - status: Literal["succeeded", "failed", "stopped"] | None = None, - page: int = 1, - limit: int = 20, - created_at__before: str = None, - created_at__after: str = None, - created_by_end_user_session_id: str = None, - created_by_account: str = None, - ): - """Get workflow execution logs with optional filtering.""" - params = {"page": page, "limit": limit} - if keyword: - params["keyword"] = keyword - if status: - params["status"] = status - if created_at__before: - params["created_at__before"] = created_at__before - if created_at__after: - params["created_at__after"] = created_at__after - if created_by_end_user_session_id: - params["created_by_end_user_session_id"] = created_by_end_user_session_id - if created_by_account: - params["created_by_account"] = created_by_account - return self._send_request("GET", "/workflows/logs", params=params) - - def run_specific_workflow( - self, - workflow_id: str, - inputs: dict, - response_mode: Literal["blocking", "streaming"] = "streaming", - user: str = "abc-123", - ): - """Run a specific workflow by workflow ID.""" - data = {"inputs": inputs, "response_mode": response_mode, "user": user} - return self._send_request( - "POST", - f"/workflows/{workflow_id}/run", - data, - stream=(response_mode == "streaming"), - ) - - # Enhanced Workflow APIs - def get_workflow_draft(self, app_id: str): - """Get workflow draft configuration. - - Args: - app_id: ID of the workflow app - - Returns: - Workflow draft configuration - """ - url = f"/apps/{app_id}/workflow/draft" - return self._send_request("GET", url) - - def update_workflow_draft(self, app_id: str, workflow_data: Dict[str, Any]): - """Update workflow draft configuration. - - Args: - app_id: ID of the workflow app - workflow_data: Workflow configuration data - - Returns: - Updated workflow draft - """ - url = f"/apps/{app_id}/workflow/draft" - return self._send_request("PUT", url, json=workflow_data) - - def publish_workflow(self, app_id: str): - """Publish workflow from draft. - - Args: - app_id: ID of the workflow app - - Returns: - Published workflow information - """ - url = f"/apps/{app_id}/workflow/publish" - return self._send_request("POST", url) - - def get_workflow_run_history( - self, - app_id: str, - page: int = 1, - limit: int = 20, - status: Literal["succeeded", "failed", "stopped"] | None = None, - ): - """Get workflow run history. - - Args: - app_id: ID of the workflow app - page: Page number (default: 1) - limit: Number of items per page (default: 20) - status: Filter by status (optional) - - Returns: - Paginated workflow run history - """ - params = {"page": page, "limit": limit} - if status: - params["status"] = status - url = f"/apps/{app_id}/workflow/runs" - return self._send_request("GET", url, params=params) - - -class WorkspaceClient(DifyClient): - """Client for workspace-related operations.""" - - def get_available_models(self, model_type: str): - """Get available models by model type.""" - url = f"/workspaces/current/models/model-types/{model_type}" - return self._send_request("GET", url) - - def get_available_models_by_type(self, model_type: str): - """Get available models by model type (enhanced version).""" - url = f"/workspaces/current/models/model-types/{model_type}" - return self._send_request("GET", url) - - def get_model_providers(self): - """Get all model providers.""" - return self._send_request("GET", "/workspaces/current/model-providers") - - def get_model_provider_models(self, provider_name: str): - """Get models for a specific provider.""" - url = f"/workspaces/current/model-providers/{provider_name}/models" - return self._send_request("GET", url) - - def validate_model_provider_credentials(self, provider_name: str, credentials: Dict[str, Any]): - """Validate model provider credentials.""" - url = f"/workspaces/current/model-providers/{provider_name}/credentials/validate" - return self._send_request("POST", url, json=credentials) - - # File Management APIs - def get_file_info(self, file_id: str): - """Get information about a specific file.""" - url = f"/files/{file_id}/info" - return self._send_request("GET", url) - - def get_file_download_url(self, file_id: str): - """Get download URL for a file.""" - url = f"/files/{file_id}/download-url" - return self._send_request("GET", url) - - def delete_file(self, file_id: str): - """Delete a file.""" - url = f"/files/{file_id}" - return self._send_request("DELETE", url) - - -class KnowledgeBaseClient(DifyClient): - def __init__( - self, - api_key: str, - base_url: str = "https://api.dify.ai/v1", - dataset_id: str | None = None, - ): - """ - Construct a KnowledgeBaseClient object. - - Args: - api_key (str): API key of Dify. - base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'. - dataset_id (str, optional): ID of the dataset. Defaults to None. You don't need this if you just want to - create a new dataset. or list datasets. otherwise you need to set this. - """ - super().__init__(api_key=api_key, base_url=base_url) - self.dataset_id = dataset_id - - def _get_dataset_id(self): - if self.dataset_id is None: - raise ValueError("dataset_id is not set") - return self.dataset_id - - def create_dataset(self, name: str, **kwargs): - return self._send_request("POST", "/datasets", {"name": name}, **kwargs) - - def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): - return self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs) - - def create_document_by_text(self, name, text, extra_params: Dict[str, Any] | None = None, **kwargs): - """ - Create a document by text. - - :param name: Name of the document - :param text: Text content of the document - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - data = { - "indexing_technique": "high_quality", - "process_rule": {"mode": "automatic"}, - "name": name, - "text": text, - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/document/create_by_text" - return self._send_request("POST", url, json=data, **kwargs) - - def update_document_by_text( - self, - document_id: str, - name: str, - text: str, - extra_params: Dict[str, Any] | None = None, - **kwargs, - ): - """ - Update a document by text. - - :param document_id: ID of the document - :param name: Name of the document - :param text: Text content of the document - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - data = {"name": name, "text": text} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text" - return self._send_request("POST", url, json=data, **kwargs) - - def create_document_by_file( - self, - file_path: str, - original_document_id: str | None = None, - extra_params: Dict[str, Any] | None = None, - ): - """ - Create a document by file. - - :param file_path: Path to the file - :param original_document_id: pass this ID if you want to replace the original document (optional) - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: Response from the API - """ - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = { - "process_rule": {"mode": "automatic"}, - "indexing_technique": "high_quality", - } - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - if original_document_id is not None: - data["original_document_id"] = original_document_id - url = f"/datasets/{self._get_dataset_id()}/document/create_by_file" - return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - def update_document_by_file( - self, - document_id: str, - file_path: str, - extra_params: Dict[str, Any] | None = None, - ): - """ - Update a document by file. - - :param document_id: ID of the document - :param file_path: Path to the file - :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional) - e.g. - { - 'indexing_technique': 'high_quality', - 'process_rule': { - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': True} - ], - 'segmentation': { - 'separator': '\n', - 'max_tokens': 500 - } - }, - 'mode': 'custom' - } - } - :return: - """ - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - data = {} - if extra_params is not None and isinstance(extra_params, dict): - data.update(extra_params) - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file" - return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files) - - def batch_indexing_status(self, batch_id: str, **kwargs): - """ - Get the status of the batch indexing. - - :param batch_id: ID of the batch uploading - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status" - return self._send_request("GET", url, **kwargs) - - def delete_dataset(self): - """ - Delete this dataset. - - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}" - return self._send_request("DELETE", url) - - def delete_document(self, document_id: str): - """ - Delete a document. - - :param document_id: ID of the document - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" - return self._send_request("DELETE", url) - - def list_documents( - self, - page: int | None = None, - page_size: int | None = None, - keyword: str | None = None, - **kwargs, - ): - """ - Get a list of documents in this dataset. - - :return: Response from the API - """ - params = {} - if page is not None: - params["page"] = page - if page_size is not None: - params["limit"] = page_size - if keyword is not None: - params["keyword"] = keyword - url = f"/datasets/{self._get_dataset_id()}/documents" - return self._send_request("GET", url, params=params, **kwargs) - - def add_segments(self, document_id: str, segments: list[dict], **kwargs): - """ - Add segments to a document. - - :param document_id: ID of the document - :param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}] - :return: Response from the API - """ - data = {"segments": segments} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - return self._send_request("POST", url, json=data, **kwargs) - - def query_segments( - self, - document_id: str, - keyword: str | None = None, - status: str | None = None, - **kwargs, - ): - """ - Query segments in this document. - - :param document_id: ID of the document - :param keyword: query keyword, optional - :param status: status of the segment, optional, e.g. completed - :param kwargs: Additional parameters to pass to the API. - Can include a 'params' dict for extra query parameters. - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments" - params = {} - if keyword is not None: - params["keyword"] = keyword - if status is not None: - params["status"] = status - if "params" in kwargs: - params.update(kwargs.pop("params")) - return self._send_request("GET", url, params=params, **kwargs) - - def delete_document_segment(self, document_id: str, segment_id: str): - """ - Delete a segment from a document. - - :param document_id: ID of the document - :param segment_id: ID of the segment - :return: Response from the API - """ - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return self._send_request("DELETE", url) - - def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs): - """ - Update a segment in a document. - - :param document_id: ID of the document - :param segment_id: ID of the segment - :param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True} - :return: Response from the API - """ - data = {"segment": segment_data} - url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" - return self._send_request("POST", url, json=data, **kwargs) - - # Advanced Knowledge Base APIs - def hit_testing( - self, - query: str, - retrieval_model: Dict[str, Any] = None, - external_retrieval_model: Dict[str, Any] = None, - ): - """Perform hit testing on the dataset.""" - data = {"query": query} - if retrieval_model: - data["retrieval_model"] = retrieval_model - if external_retrieval_model: - data["external_retrieval_model"] = external_retrieval_model - url = f"/datasets/{self._get_dataset_id()}/hit-testing" - return self._send_request("POST", url, json=data) - - def get_dataset_metadata(self): - """Get dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return self._send_request("GET", url) - - def create_dataset_metadata(self, metadata_data: Dict[str, Any]): - """Create dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata" - return self._send_request("POST", url, json=metadata_data) - - def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]): - """Update dataset metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}" - return self._send_request("PATCH", url, json=metadata_data) - - def get_built_in_metadata(self): - """Get built-in metadata.""" - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in" - return self._send_request("GET", url) - - def manage_built_in_metadata(self, action: str, metadata_data: Dict[str, Any] = None): - """Manage built-in metadata with specified action.""" - data = metadata_data or {} - url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}" - return self._send_request("POST", url, json=data) - - def update_documents_metadata(self, operation_data: List[Dict[str, Any]]): - """Update metadata for multiple documents.""" - url = f"/datasets/{self._get_dataset_id()}/documents/metadata" - data = {"operation_data": operation_data} - return self._send_request("POST", url, json=data) - - # Dataset Tags APIs - def list_dataset_tags(self): - """List all dataset tags.""" - return self._send_request("GET", "/datasets/tags") - - def bind_dataset_tags(self, tag_ids: List[str]): - """Bind tags to dataset.""" - data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()} - return self._send_request("POST", "/datasets/tags/binding", json=data) - - def unbind_dataset_tag(self, tag_id: str): - """Unbind a single tag from dataset.""" - data = {"tag_id": tag_id, "target_id": self._get_dataset_id()} - return self._send_request("POST", "/datasets/tags/unbinding", json=data) - - def get_dataset_tags(self): - """Get tags for current dataset.""" - url = f"/datasets/{self._get_dataset_id()}/tags" - return self._send_request("GET", url) - - # RAG Pipeline APIs - def get_datasource_plugins(self, is_published: bool = True): - """Get datasource plugins for RAG pipeline.""" - params = {"is_published": is_published} - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins" - return self._send_request("GET", url, params=params) - - def run_datasource_node( - self, - node_id: str, - inputs: Dict[str, Any], - datasource_type: str, - is_published: bool = True, - credential_id: str = None, - ): - """Run a datasource node in RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "is_published": is_published, - } - if credential_id: - data["credential_id"] = credential_id - url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run" - return self._send_request("POST", url, json=data, stream=True) - - def run_rag_pipeline( - self, - inputs: Dict[str, Any], - datasource_type: str, - datasource_info_list: List[Dict[str, Any]], - start_node_id: str, - is_published: bool = True, - response_mode: Literal["streaming", "blocking"] = "blocking", - ): - """Run RAG pipeline.""" - data = { - "inputs": inputs, - "datasource_type": datasource_type, - "datasource_info_list": datasource_info_list, - "start_node_id": start_node_id, - "is_published": is_published, - "response_mode": response_mode, - } - url = f"/datasets/{self._get_dataset_id()}/pipeline/run" - return self._send_request("POST", url, json=data, stream=response_mode == "streaming") - - def upload_pipeline_file(self, file_path: str): - """Upload file for RAG pipeline.""" - with open(file_path, "rb") as f: - files = {"file": (os.path.basename(file_path), f)} - return self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files) - - # Dataset Management APIs - def get_dataset(self, dataset_id: str | None = None): - """Get detailed information about a specific dataset. - - Args: - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - - Returns: - Response from the API containing dataset details including: - - name, description, permission - - indexing_technique, embedding_model, embedding_model_provider - - retrieval_model configuration - - document_count, word_count, app_count - - created_at, updated_at - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - return self._send_request("GET", url) - - def update_dataset( - self, - dataset_id: str | None = None, - name: str | None = None, - description: str | None = None, - indexing_technique: str | None = None, - embedding_model: str | None = None, - embedding_model_provider: str | None = None, - retrieval_model: Dict[str, Any] | None = None, - **kwargs, - ): - """Update dataset configuration. - - Args: - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - name: New dataset name - description: New dataset description - indexing_technique: Indexing technique ('high_quality' or 'economy') - embedding_model: Embedding model name - embedding_model_provider: Embedding model provider - retrieval_model: Retrieval model configuration dict - **kwargs: Additional parameters to pass to the API - - Returns: - Response from the API with updated dataset information - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}" - - # Build data dictionary with all possible parameters - payload = { - "name": name, - "description": description, - "indexing_technique": indexing_technique, - "embedding_model": embedding_model, - "embedding_model_provider": embedding_model_provider, - "retrieval_model": retrieval_model, - } - - # Filter out None values and merge with additional kwargs - data = {k: v for k, v in payload.items() if v is not None} - data.update(kwargs) - - return self._send_request("PATCH", url, json=data) - - def batch_update_document_status( - self, - action: Literal["enable", "disable", "archive", "un_archive"], - document_ids: List[str], - dataset_id: str | None = None, - ): - """Batch update document status (enable/disable/archive/unarchive). - - Args: - action: Action to perform on documents - - 'enable': Enable documents for retrieval - - 'disable': Disable documents from retrieval - - 'archive': Archive documents - - 'un_archive': Unarchive documents - document_ids: List of document IDs to update - dataset_id: Dataset ID (optional, uses current dataset_id if not provided) - - Returns: - Response from the API with operation result - """ - ds_id = dataset_id or self._get_dataset_id() - url = f"/datasets/{ds_id}/documents/status/{action}" - data = {"document_ids": document_ids} - return self._send_request("PATCH", url, json=data) - - # Enhanced Dataset APIs - def create_dataset_from_template(self, template_name: str, name: str, description: str | None = None): - """Create a dataset from a predefined template. - - Args: - template_name: Name of the template to use - name: Name for the new dataset - description: Description for the dataset (optional) - - Returns: - Created dataset information - """ - data = { - "template_name": template_name, - "name": name, - "description": description, - } - return self._send_request("POST", "/datasets/from-template", json=data) - - def duplicate_dataset(self, dataset_id: str, name: str): - """Duplicate an existing dataset. - - Args: - dataset_id: ID of dataset to duplicate - name: Name for duplicated dataset - - Returns: - New dataset information - """ - data = {"name": name} - url = f"/datasets/{dataset_id}/duplicate" - return self._send_request("POST", url, json=data) - - def list_conversation_variables_with_pagination( - self, conversation_id: str, user: str, page: int = 1, limit: int = 20 - ): - """List conversation variables with pagination.""" - params = {"page": page, "limit": limit, "user": user} - url = f"/conversations/{conversation_id}/variables" - return self._send_request("GET", url, params=params) - - def update_conversation_variable_with_response(self, conversation_id: str, variable_id: str, user: str, value: Any): - """Update a conversation variable with full response handling.""" - data = {"value": value, "user": user} - url = f"/conversations/{conversation_id}/variables/{variable_id}" - return self._send_request("PUT", url, json=data) diff --git a/sdks/python-client/dify_client/exceptions.py b/sdks/python-client/dify_client/exceptions.py deleted file mode 100644 index e7ba2ff4b2..0000000000 --- a/sdks/python-client/dify_client/exceptions.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Custom exceptions for the Dify client.""" - -from typing import Optional, Dict, Any - - -class DifyClientError(Exception): - """Base exception for all Dify client errors.""" - - def __init__(self, message: str, status_code: int | None = None, response: Dict[str, Any] | None = None): - super().__init__(message) - self.message = message - self.status_code = status_code - self.response = response - - -class APIError(DifyClientError): - """Raised when the API returns an error response.""" - - def __init__(self, message: str, status_code: int, response: Dict[str, Any] | None = None): - super().__init__(message, status_code, response) - self.status_code = status_code - - -class AuthenticationError(DifyClientError): - """Raised when authentication fails.""" - - pass - - -class RateLimitError(DifyClientError): - """Raised when rate limit is exceeded.""" - - def __init__(self, message: str = "Rate limit exceeded", retry_after: int | None = None): - super().__init__(message) - self.retry_after = retry_after - - -class ValidationError(DifyClientError): - """Raised when request validation fails.""" - - pass - - -class NetworkError(DifyClientError): - """Raised when network-related errors occur.""" - - pass - - -class TimeoutError(DifyClientError): - """Raised when request times out.""" - - pass - - -class FileUploadError(DifyClientError): - """Raised when file upload fails.""" - - pass - - -class DatasetError(DifyClientError): - """Raised when dataset operations fail.""" - - pass - - -class WorkflowError(DifyClientError): - """Raised when workflow operations fail.""" - - pass diff --git a/sdks/python-client/dify_client/models.py b/sdks/python-client/dify_client/models.py deleted file mode 100644 index 0321e9c3f4..0000000000 --- a/sdks/python-client/dify_client/models.py +++ /dev/null @@ -1,396 +0,0 @@ -"""Response models for the Dify client with proper type hints.""" - -from typing import Optional, List, Dict, Any, Literal, Union -from dataclasses import dataclass, field -from datetime import datetime - - -@dataclass -class BaseResponse: - """Base response model.""" - - success: bool = True - message: str | None = None - - -@dataclass -class ErrorResponse(BaseResponse): - """Error response model.""" - - error_code: str | None = None - details: Dict[str, Any] | None = None - success: bool = False - - -@dataclass -class FileInfo: - """File information model.""" - - id: str - name: str - size: int - mime_type: str - url: str | None = None - created_at: datetime | None = None - - -@dataclass -class MessageResponse(BaseResponse): - """Message response model.""" - - id: str = "" - answer: str = "" - conversation_id: str | None = None - created_at: int | None = None - metadata: Dict[str, Any] | None = None - files: List[Dict[str, Any]] | None = None - - -@dataclass -class ConversationResponse(BaseResponse): - """Conversation response model.""" - - id: str = "" - name: str = "" - inputs: Dict[str, Any] | None = None - status: str | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DatasetResponse(BaseResponse): - """Dataset response model.""" - - id: str = "" - name: str = "" - description: str | None = None - permission: str | None = None - indexing_technique: str | None = None - embedding_model: str | None = None - embedding_model_provider: str | None = None - retrieval_model: Dict[str, Any] | None = None - document_count: int | None = None - word_count: int | None = None - app_count: int | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DocumentResponse(BaseResponse): - """Document response model.""" - - id: str = "" - name: str = "" - data_source_type: str | None = None - data_source_info: Dict[str, Any] | None = None - dataset_process_rule_id: str | None = None - batch: str | None = None - position: int | None = None - enabled: bool | None = None - disabled_at: float | None = None - disabled_by: str | None = None - archived: bool | None = None - archived_reason: str | None = None - archived_at: float | None = None - archived_by: str | None = None - word_count: int | None = None - hit_count: int | None = None - doc_form: str | None = None - doc_metadata: Dict[str, Any] | None = None - created_at: float | None = None - updated_at: float | None = None - indexing_status: str | None = None - completed_at: float | None = None - paused_at: float | None = None - error: str | None = None - stopped_at: float | None = None - - -@dataclass -class DocumentSegmentResponse(BaseResponse): - """Document segment response model.""" - - id: str = "" - position: int | None = None - document_id: str | None = None - content: str | None = None - answer: str | None = None - word_count: int | None = None - tokens: int | None = None - keywords: List[str] | None = None - index_node_id: str | None = None - index_node_hash: str | None = None - hit_count: int | None = None - enabled: bool | None = None - disabled_at: float | None = None - disabled_by: str | None = None - status: str | None = None - created_by: str | None = None - created_at: float | None = None - indexing_at: float | None = None - completed_at: float | None = None - error: str | None = None - stopped_at: float | None = None - - -@dataclass -class WorkflowRunResponse(BaseResponse): - """Workflow run response model.""" - - id: str = "" - workflow_id: str | None = None - status: Literal["running", "succeeded", "failed", "stopped"] | None = None - inputs: Dict[str, Any] | None = None - outputs: Dict[str, Any] | None = None - error: str | None = None - elapsed_time: float | None = None - total_tokens: int | None = None - total_steps: int | None = None - created_at: float | None = None - finished_at: float | None = None - - -@dataclass -class ApplicationParametersResponse(BaseResponse): - """Application parameters response model.""" - - opening_statement: str | None = None - suggested_questions: List[str] | None = None - speech_to_text: Dict[str, Any] | None = None - text_to_speech: Dict[str, Any] | None = None - retriever_resource: Dict[str, Any] | None = None - sensitive_word_avoidance: Dict[str, Any] | None = None - file_upload: Dict[str, Any] | None = None - system_parameters: Dict[str, Any] | None = None - user_input_form: List[Dict[str, Any]] | None = None - - -@dataclass -class AnnotationResponse(BaseResponse): - """Annotation response model.""" - - id: str = "" - question: str = "" - answer: str = "" - content: str | None = None - created_at: float | None = None - updated_at: float | None = None - created_by: str | None = None - updated_by: str | None = None - hit_count: int | None = None - - -@dataclass -class PaginatedResponse(BaseResponse): - """Paginated response model.""" - - data: List[Any] = field(default_factory=list) - has_more: bool = False - limit: int = 0 - total: int = 0 - page: int | None = None - - -@dataclass -class ConversationVariableResponse(BaseResponse): - """Conversation variable response model.""" - - conversation_id: str = "" - variables: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class FileUploadResponse(BaseResponse): - """File upload response model.""" - - id: str = "" - name: str = "" - size: int = 0 - mime_type: str = "" - url: str | None = None - created_at: float | None = None - - -@dataclass -class AudioResponse(BaseResponse): - """Audio generation/response model.""" - - audio: str | None = None # Base64 encoded audio data or URL - audio_url: str | None = None - duration: float | None = None - sample_rate: int | None = None - - -@dataclass -class SuggestedQuestionsResponse(BaseResponse): - """Suggested questions response model.""" - - message_id: str = "" - questions: List[str] = field(default_factory=list) - - -@dataclass -class AppInfoResponse(BaseResponse): - """App info response model.""" - - id: str = "" - name: str = "" - description: str | None = None - icon: str | None = None - icon_background: str | None = None - mode: str | None = None - tags: List[str] | None = None - enable_site: bool | None = None - enable_api: bool | None = None - api_token: str | None = None - - -@dataclass -class WorkspaceModelsResponse(BaseResponse): - """Workspace models response model.""" - - models: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class HitTestingResponse(BaseResponse): - """Hit testing response model.""" - - query: str = "" - records: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class DatasetTagsResponse(BaseResponse): - """Dataset tags response model.""" - - tags: List[Dict[str, Any]] = field(default_factory=list) - - -@dataclass -class WorkflowLogsResponse(BaseResponse): - """Workflow logs response model.""" - - logs: List[Dict[str, Any]] = field(default_factory=list) - total: int = 0 - page: int = 0 - limit: int = 0 - has_more: bool = False - - -@dataclass -class ModelProviderResponse(BaseResponse): - """Model provider response model.""" - - provider_name: str = "" - provider_type: str = "" - models: List[Dict[str, Any]] = field(default_factory=list) - is_enabled: bool = False - credentials: Dict[str, Any] | None = None - - -@dataclass -class FileInfoResponse(BaseResponse): - """File info response model.""" - - id: str = "" - name: str = "" - size: int = 0 - mime_type: str = "" - url: str | None = None - created_at: int | None = None - metadata: Dict[str, Any] | None = None - - -@dataclass -class WorkflowDraftResponse(BaseResponse): - """Workflow draft response model.""" - - id: str = "" - app_id: str = "" - draft_data: Dict[str, Any] = field(default_factory=dict) - version: int = 0 - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class ApiTokenResponse(BaseResponse): - """API token response model.""" - - id: str = "" - name: str = "" - token: str = "" - description: str | None = None - created_at: int | None = None - last_used_at: int | None = None - is_active: bool = True - - -@dataclass -class JobStatusResponse(BaseResponse): - """Job status response model.""" - - job_id: str = "" - job_status: str = "" - error_msg: str | None = None - progress: float | None = None - created_at: int | None = None - updated_at: int | None = None - - -@dataclass -class DatasetQueryResponse(BaseResponse): - """Dataset query response model.""" - - query: str = "" - records: List[Dict[str, Any]] = field(default_factory=list) - total: int = 0 - search_time: float | None = None - retrieval_model: Dict[str, Any] | None = None - - -@dataclass -class DatasetTemplateResponse(BaseResponse): - """Dataset template response model.""" - - template_name: str = "" - display_name: str = "" - description: str = "" - category: str = "" - icon: str | None = None - config_schema: Dict[str, Any] = field(default_factory=dict) - - -# Type aliases for common response types -ResponseType = Union[ - BaseResponse, - ErrorResponse, - MessageResponse, - ConversationResponse, - DatasetResponse, - DocumentResponse, - DocumentSegmentResponse, - WorkflowRunResponse, - ApplicationParametersResponse, - AnnotationResponse, - PaginatedResponse, - ConversationVariableResponse, - FileUploadResponse, - AudioResponse, - SuggestedQuestionsResponse, - AppInfoResponse, - WorkspaceModelsResponse, - HitTestingResponse, - DatasetTagsResponse, - WorkflowLogsResponse, - ModelProviderResponse, - FileInfoResponse, - WorkflowDraftResponse, - ApiTokenResponse, - JobStatusResponse, - DatasetQueryResponse, - DatasetTemplateResponse, -] diff --git a/sdks/python-client/examples/advanced_usage.py b/sdks/python-client/examples/advanced_usage.py deleted file mode 100644 index bc8720bef2..0000000000 --- a/sdks/python-client/examples/advanced_usage.py +++ /dev/null @@ -1,264 +0,0 @@ -""" -Advanced usage examples for the Dify Python SDK. - -This example demonstrates: -- Error handling and retries -- Logging configuration -- Context managers -- Async usage -- File uploads -- Dataset management -""" - -import asyncio -import logging -from pathlib import Path - -from dify_client import ( - ChatClient, - CompletionClient, - AsyncChatClient, - KnowledgeBaseClient, - DifyClient, -) -from dify_client.exceptions import ( - APIError, - RateLimitError, - AuthenticationError, - DifyClientError, -) - - -def setup_logging(): - """Setup logging for the SDK.""" - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - - -def example_chat_with_error_handling(): - """Example of chat with comprehensive error handling.""" - api_key = "your-api-key-here" - - try: - with ChatClient(api_key, enable_logging=True) as client: - # Simple chat message - response = client.create_chat_message( - inputs={}, query="Hello, how are you?", user="user-123", response_mode="blocking" - ) - - result = response.json() - print(f"Response: {result.get('answer')}") - - except AuthenticationError as e: - print(f"Authentication failed: {e}") - print("Please check your API key") - - except RateLimitError as e: - print(f"Rate limit exceeded: {e}") - if e.retry_after: - print(f"Retry after {e.retry_after} seconds") - - except APIError as e: - print(f"API error: {e.message}") - print(f"Status code: {e.status_code}") - - except DifyClientError as e: - print(f"Dify client error: {e}") - - except Exception as e: - print(f"Unexpected error: {e}") - - -def example_completion_with_files(): - """Example of completion with file upload.""" - api_key = "your-api-key-here" - - with CompletionClient(api_key) as client: - # Upload an image file first - file_path = "path/to/your/image.jpg" - - try: - with open(file_path, "rb") as f: - files = {"file": (Path(file_path).name, f, "image/jpeg")} - upload_response = client.file_upload("user-123", files) - upload_response.raise_for_status() - - file_id = upload_response.json().get("id") - print(f"File uploaded with ID: {file_id}") - - # Use the uploaded file in completion - files_list = [{"type": "image", "transfer_method": "local_file", "upload_file_id": file_id}] - - completion_response = client.create_completion_message( - inputs={"query": "Describe this image"}, response_mode="blocking", user="user-123", files=files_list - ) - - result = completion_response.json() - print(f"Completion result: {result.get('answer')}") - - except FileNotFoundError: - print(f"File not found: {file_path}") - except Exception as e: - print(f"Error during file upload/completion: {e}") - - -def example_dataset_management(): - """Example of dataset management operations.""" - api_key = "your-api-key-here" - - with KnowledgeBaseClient(api_key) as kb_client: - try: - # Create a new dataset - create_response = kb_client.create_dataset(name="My Test Dataset") - create_response.raise_for_status() - - dataset_id = create_response.json().get("id") - print(f"Created dataset with ID: {dataset_id}") - - # Create a client with the dataset ID - dataset_client = KnowledgeBaseClient(api_key, dataset_id=dataset_id) - - # Add a document by text - doc_response = dataset_client.create_document_by_text( - name="Test Document", text="This is a test document for the knowledge base." - ) - doc_response.raise_for_status() - - document_id = doc_response.json().get("document", {}).get("id") - print(f"Created document with ID: {document_id}") - - # List documents - list_response = dataset_client.list_documents() - list_response.raise_for_status() - - documents = list_response.json().get("data", []) - print(f"Dataset contains {len(documents)} documents") - - # Update dataset configuration - update_response = dataset_client.update_dataset( - name="Updated Dataset Name", description="Updated description", indexing_technique="high_quality" - ) - update_response.raise_for_status() - - print("Dataset updated successfully") - - except Exception as e: - print(f"Dataset management error: {e}") - - -async def example_async_chat(): - """Example of async chat usage.""" - api_key = "your-api-key-here" - - try: - async with AsyncChatClient(api_key) as client: - # Create chat message - response = await client.create_chat_message( - inputs={}, query="What's the weather like?", user="user-456", response_mode="blocking" - ) - - result = response.json() - print(f"Async response: {result.get('answer')}") - - # Get conversations - conversations = await client.get_conversations("user-456") - conversations.raise_for_status() - - conv_data = conversations.json() - print(f"Found {len(conv_data.get('data', []))} conversations") - - except Exception as e: - print(f"Async chat error: {e}") - - -def example_streaming_response(): - """Example of handling streaming responses.""" - api_key = "your-api-key-here" - - with ChatClient(api_key) as client: - try: - response = client.create_chat_message( - inputs={}, query="Tell me a story", user="user-789", response_mode="streaming" - ) - - print("Streaming response:") - for line in response.iter_lines(decode_unicode=True): - if line.startswith("data:"): - data = line[5:].strip() - if data: - import json - - try: - chunk = json.loads(data) - answer = chunk.get("answer", "") - if answer: - print(answer, end="", flush=True) - except json.JSONDecodeError: - continue - print() # New line after streaming - - except Exception as e: - print(f"Streaming error: {e}") - - -def example_application_info(): - """Example of getting application information.""" - api_key = "your-api-key-here" - - with DifyClient(api_key) as client: - try: - # Get app info - info_response = client.get_app_info() - info_response.raise_for_status() - - app_info = info_response.json() - print(f"App name: {app_info.get('name')}") - print(f"App mode: {app_info.get('mode')}") - print(f"App tags: {app_info.get('tags', [])}") - - # Get app parameters - params_response = client.get_application_parameters("user-123") - params_response.raise_for_status() - - params = params_response.json() - print(f"Opening statement: {params.get('opening_statement')}") - print(f"Suggested questions: {params.get('suggested_questions', [])}") - - except Exception as e: - print(f"App info error: {e}") - - -def main(): - """Run all examples.""" - setup_logging() - - print("=== Dify Python SDK Advanced Usage Examples ===\n") - - print("1. Chat with Error Handling:") - example_chat_with_error_handling() - print() - - print("2. Completion with Files:") - example_completion_with_files() - print() - - print("3. Dataset Management:") - example_dataset_management() - print() - - print("4. Async Chat:") - asyncio.run(example_async_chat()) - print() - - print("5. Streaming Response:") - example_streaming_response() - print() - - print("6. Application Info:") - example_application_info() - print() - - print("All examples completed!") - - -if __name__ == "__main__": - main() diff --git a/sdks/python-client/pyproject.toml b/sdks/python-client/pyproject.toml deleted file mode 100644 index a25cb9150c..0000000000 --- a/sdks/python-client/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -name = "dify-client" -version = "0.1.12" -description = "A package for interacting with the Dify Service-API" -readme = "README.md" -requires-python = ">=3.10" -dependencies = [ - "httpx[http2]>=0.27.0", - "aiofiles>=23.0.0", -] -authors = [ - {name = "Dify", email = "hello@dify.ai"} -] -license = {text = "MIT"} -keywords = ["dify", "nlp", "ai", "language-processing"] -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", -] - -[project.urls] -Homepage = "https://github.com/langgenius/dify" - -[project.optional-dependencies] -dev = [ - "pytest>=7.0.0", - "pytest-asyncio>=0.21.0", -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["dify_client"] - -[tool.pytest.ini_options] -testpaths = ["tests"] -python_files = ["test_*.py"] -python_classes = ["Test*"] -python_functions = ["test_*"] -asyncio_mode = "auto" diff --git a/sdks/python-client/tests/test_async_client.py b/sdks/python-client/tests/test_async_client.py deleted file mode 100644 index 4f5001866f..0000000000 --- a/sdks/python-client/tests/test_async_client.py +++ /dev/null @@ -1,250 +0,0 @@ -#!/usr/bin/env python3 -""" -Test suite for async client implementation in the Python SDK. - -This test validates the async/await functionality using httpx.AsyncClient -and ensures API parity with sync clients. -""" - -import unittest -from unittest.mock import Mock, patch, AsyncMock - -from dify_client.async_client import ( - AsyncDifyClient, - AsyncChatClient, - AsyncCompletionClient, - AsyncWorkflowClient, - AsyncWorkspaceClient, - AsyncKnowledgeBaseClient, -) - - -class TestAsyncAPIParity(unittest.TestCase): - """Test that async clients have API parity with sync clients.""" - - def test_dify_client_api_parity(self): - """Test AsyncDifyClient has same methods as DifyClient.""" - from dify_client import DifyClient - - sync_methods = {name for name in dir(DifyClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncDifyClient) if not name.startswith("_")} - - # aclose is async-specific, close is sync-specific - sync_methods.discard("close") - async_methods.discard("aclose") - - # Verify parity - self.assertEqual(sync_methods, async_methods, "API parity mismatch for DifyClient") - - def test_chat_client_api_parity(self): - """Test AsyncChatClient has same methods as ChatClient.""" - from dify_client import ChatClient - - sync_methods = {name for name in dir(ChatClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncChatClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for ChatClient") - - def test_completion_client_api_parity(self): - """Test AsyncCompletionClient has same methods as CompletionClient.""" - from dify_client import CompletionClient - - sync_methods = {name for name in dir(CompletionClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncCompletionClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for CompletionClient") - - def test_workflow_client_api_parity(self): - """Test AsyncWorkflowClient has same methods as WorkflowClient.""" - from dify_client import WorkflowClient - - sync_methods = {name for name in dir(WorkflowClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncWorkflowClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkflowClient") - - def test_workspace_client_api_parity(self): - """Test AsyncWorkspaceClient has same methods as WorkspaceClient.""" - from dify_client import WorkspaceClient - - sync_methods = {name for name in dir(WorkspaceClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncWorkspaceClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkspaceClient") - - def test_knowledge_base_client_api_parity(self): - """Test AsyncKnowledgeBaseClient has same methods as KnowledgeBaseClient.""" - from dify_client import KnowledgeBaseClient - - sync_methods = {name for name in dir(KnowledgeBaseClient) if not name.startswith("_")} - async_methods = {name for name in dir(AsyncKnowledgeBaseClient) if not name.startswith("_")} - - sync_methods.discard("close") - async_methods.discard("aclose") - - self.assertEqual(sync_methods, async_methods, "API parity mismatch for KnowledgeBaseClient") - - -class TestAsyncClientMocked(unittest.IsolatedAsyncioTestCase): - """Test async client with mocked httpx.AsyncClient.""" - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_client_initialization(self, mock_httpx_async_client): - """Test async client initializes with httpx.AsyncClient.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - client = AsyncDifyClient("test-key", "https://api.dify.ai/v1") - - # Verify httpx.AsyncClient was called - mock_httpx_async_client.assert_called_once() - self.assertEqual(client.api_key, "test-key") - - await client.aclose() - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_context_manager(self, mock_httpx_async_client): - """Test async context manager works.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncDifyClient("test-key") as client: - self.assertEqual(client.api_key, "test-key") - - # Verify aclose was called - mock_client_instance.aclose.assert_called_once() - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_send_request(self, mock_httpx_async_client): - """Test async _send_request method.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"result": "success"}) - mock_response.status_code = 200 - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncDifyClient("test-key") as client: - response = await client._send_request("GET", "/test") - - # Verify request was called - mock_client_instance.request.assert_called_once() - call_args = mock_client_instance.request.call_args - - # Verify parameters - self.assertEqual(call_args[0][0], "GET") - self.assertEqual(call_args[0][1], "/test") - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_chat_client(self, mock_httpx_async_client): - """Test AsyncChatClient functionality.""" - mock_response = AsyncMock() - mock_response.text = '{"answer": "Hello!"}' - mock_response.json = AsyncMock(return_value={"answer": "Hello!"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncChatClient("test-key") as client: - response = await client.create_chat_message({}, "Hi", "user123") - self.assertIn("answer", response.text) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_completion_client(self, mock_httpx_async_client): - """Test AsyncCompletionClient functionality.""" - mock_response = AsyncMock() - mock_response.text = '{"answer": "Response"}' - mock_response.json = AsyncMock(return_value={"answer": "Response"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncCompletionClient("test-key") as client: - response = await client.create_completion_message({"query": "test"}, "blocking", "user123") - self.assertIn("answer", response.text) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_workflow_client(self, mock_httpx_async_client): - """Test AsyncWorkflowClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"result": "success"}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncWorkflowClient("test-key") as client: - response = await client.run({"input": "test"}, "blocking", "user123") - data = await response.json() - self.assertEqual(data["result"], "success") - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_workspace_client(self, mock_httpx_async_client): - """Test AsyncWorkspaceClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"data": []}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncWorkspaceClient("test-key") as client: - response = await client.get_available_models("llm") - data = await response.json() - self.assertIn("data", data) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_async_knowledge_base_client(self, mock_httpx_async_client): - """Test AsyncKnowledgeBaseClient functionality.""" - mock_response = AsyncMock() - mock_response.json = AsyncMock(return_value={"data": [], "total": 0}) - - mock_client_instance = AsyncMock() - mock_client_instance.request = AsyncMock(return_value=mock_response) - mock_httpx_async_client.return_value = mock_client_instance - - async with AsyncKnowledgeBaseClient("test-key") as client: - response = await client.list_datasets() - data = await response.json() - self.assertIn("data", data) - - @patch("dify_client.async_client.httpx.AsyncClient") - async def test_all_async_client_classes(self, mock_httpx_async_client): - """Test all async client classes work with httpx.AsyncClient.""" - mock_client_instance = AsyncMock() - mock_httpx_async_client.return_value = mock_client_instance - - clients = [ - AsyncDifyClient("key"), - AsyncChatClient("key"), - AsyncCompletionClient("key"), - AsyncWorkflowClient("key"), - AsyncWorkspaceClient("key"), - AsyncKnowledgeBaseClient("key"), - ] - - # Verify httpx.AsyncClient was called for each - self.assertEqual(mock_httpx_async_client.call_count, 6) - - # Clean up - for client in clients: - await client.aclose() - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py deleted file mode 100644 index b0d2f8ba23..0000000000 --- a/sdks/python-client/tests/test_client.py +++ /dev/null @@ -1,489 +0,0 @@ -import os -import time -import unittest -from unittest.mock import Mock, patch, mock_open - -from dify_client.client import ( - ChatClient, - CompletionClient, - DifyClient, - KnowledgeBaseClient, -) - -API_KEY = os.environ.get("API_KEY") -APP_ID = os.environ.get("APP_ID") -API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.dify.ai/v1") -FILE_PATH_BASE = os.path.dirname(__file__) - - -class TestKnowledgeBaseClient(unittest.TestCase): - def setUp(self): - self.api_key = "test-api-key" - self.base_url = "https://api.dify.ai/v1" - self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url) - self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md")) - self.dataset_id = "test-dataset-id" - self.document_id = "test-document-id" - self.segment_id = "test-segment-id" - self.batch_id = "test-batch-id" - - def _get_dataset_kb_client(self): - return KnowledgeBaseClient(self.api_key, base_url=self.base_url, dataset_id=self.dataset_id) - - @patch("dify_client.client.httpx.Client") - def test_001_create_dataset(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.json.return_value = {"id": self.dataset_id, "name": "test_dataset"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Re-create client with mocked httpx - self.knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url) - - response = self.knowledge_base_client.create_dataset(name="test_dataset") - data = response.json() - self.assertIn("id", data) - self.assertEqual("test_dataset", data["name"]) - - # the following tests require to be executed in order because they use - # the dataset/document/segment ids from the previous test - self._test_002_list_datasets() - self._test_003_create_document_by_text() - self._test_004_update_document_by_text() - self._test_006_update_document_by_file() - self._test_007_list_documents() - self._test_008_delete_document() - self._test_009_create_document_by_file() - self._test_010_add_segments() - self._test_011_query_segments() - self._test_012_update_document_segment() - self._test_013_delete_document_segment() - self._test_014_delete_dataset() - - def _test_002_list_datasets(self): - # Mock the response - using the already mocked client from test_001_create_dataset - mock_response = Mock() - mock_response.json.return_value = {"data": [], "total": 0} - mock_response.status_code = 200 - self.knowledge_base_client._client.request.return_value = mock_response - - response = self.knowledge_base_client.list_datasets() - data = response.json() - self.assertIn("data", data) - self.assertIn("total", data) - - def _test_003_create_document_by_text(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.create_document_by_text("test_document", "test_text") - data = response.json() - self.assertIn("document", data) - - def _test_004_update_document_by_text(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated") - data = response.json() - self.assertIn("document", data) - self.assertIn("batch", data) - - def _test_006_update_document_by_file(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_by_file(self.document_id, self.README_FILE_PATH) - data = response.json() - self.assertIn("document", data) - self.assertIn("batch", data) - - def _test_007_list_documents(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": []} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.list_documents() - data = response.json() - self.assertIn("data", data) - - def _test_008_delete_document(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.delete_document(self.document_id) - data = response.json() - self.assertIn("result", data) - self.assertEqual("success", data["result"]) - - def _test_009_create_document_by_file(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"document": {"id": self.document_id}, "batch": self.batch_id} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.create_document_by_file(self.README_FILE_PATH) - data = response.json() - self.assertIn("document", data) - - def _test_010_add_segments(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.add_segments(self.document_id, [{"content": "test text segment 1"}]) - data = response.json() - self.assertIn("data", data) - self.assertGreater(len(data["data"]), 0) - - def _test_011_query_segments(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": [{"id": self.segment_id, "content": "test text segment 1"}]} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.query_segments(self.document_id) - data = response.json() - self.assertIn("data", data) - self.assertGreater(len(data["data"]), 0) - - def _test_012_update_document_segment(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"data": {"id": self.segment_id, "content": "test text segment 1 updated"}} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.update_document_segment( - self.document_id, - self.segment_id, - {"content": "test text segment 1 updated"}, - ) - data = response.json() - self.assertIn("data", data) - self.assertEqual("test text segment 1 updated", data["data"]["content"]) - - def _test_013_delete_document_segment(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - client._client.request.return_value = mock_response - - response = client.delete_document_segment(self.document_id, self.segment_id) - data = response.json() - self.assertIn("result", data) - self.assertEqual("success", data["result"]) - - def _test_014_delete_dataset(self): - client = self._get_dataset_kb_client() - # Mock the response - mock_response = Mock() - mock_response.status_code = 204 - client._client.request.return_value = mock_response - - response = client.delete_dataset() - self.assertEqual(204, response.status_code) - - -class TestChatClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.chat_client = ChatClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"answer": "Hello! This is a test response."}' - mock_response.json.return_value = {"answer": "Hello! This is a test response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "Hello! This is a test response."}' - mock_response.json.return_value = {"answer": "Hello! This is a test response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.create_chat_message({}, "Hello, World!", "test_user") - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_with_vision_model_by_remote_url(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "I can see this is a test image description."}' - mock_response.json.return_value = {"answer": "I can see this is a test image description."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}] - response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_with_vision_model_by_local_file(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "I can see this is a test uploaded image."}' - mock_response.json.return_value = {"answer": "I can see this is a test uploaded image."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - files = [ - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": "test-file-id", - } - ] - response = chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_conversation_messages(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "Here are the conversation messages."}' - mock_response.json.return_value = {"answer": "Here are the conversation messages."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.get_conversation_messages("test_user", "test-conversation-id") - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_conversations(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"data": [{"id": "conv1", "name": "Test Conversation"}]}' - mock_response.json.return_value = {"data": [{"id": "conv1", "name": "Test Conversation"}]} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - chat_client = ChatClient(self.api_key) - response = chat_client.get_conversations("test_user") - self.assertIn("data", response.text) - - -class TestCompletionClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.completion_client = CompletionClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"answer": "This is a test completion response."}' - mock_response.json.return_value = {"answer": "This is a test completion response."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "The weather today is sunny with a temperature of 75°F."}' - mock_response.json.return_value = {"answer": "The weather today is sunny with a temperature of 75°F."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - response = completion_client.create_completion_message( - {"query": "What's the weather like today?"}, "blocking", "test_user" - ) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_with_vision_model_by_remote_url(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "This is a test image description from completion API."}' - mock_response.json.return_value = {"answer": "This is a test image description from completion API."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - files = [{"type": "image", "transfer_method": "remote_url", "url": "https://example.com/test-image.jpg"}] - response = completion_client.create_completion_message( - {"query": "Describe the picture."}, "blocking", "test_user", files - ) - self.assertIn("answer", response.text) - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_with_vision_model_by_local_file(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"answer": "This is a test uploaded image description from completion API."}' - mock_response.json.return_value = {"answer": "This is a test uploaded image description from completion API."} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - completion_client = CompletionClient(self.api_key) - files = [ - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": "test-file-id", - } - ] - response = completion_client.create_completion_message( - {"query": "Describe the picture."}, "blocking", "test_user", files - ) - self.assertIn("answer", response.text) - - -class TestDifyClient(unittest.TestCase): - @patch("dify_client.client.httpx.Client") - def setUp(self, mock_httpx_client): - self.api_key = "test-api-key" - self.dify_client = DifyClient(self.api_key) - - # Set up default mock response for the client - mock_response = Mock() - mock_response.text = '{"result": "success"}' - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - @patch("dify_client.client.httpx.Client") - def test_message_feedback(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"success": true}' - mock_response.json.return_value = {"success": True} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - response = dify_client.message_feedback("test-message-id", "like", "test_user") - self.assertIn("success", response.text) - - @patch("dify_client.client.httpx.Client") - def test_get_application_parameters(self, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"user_input_form": [{"field": "text", "label": "Input"}]}' - mock_response.json.return_value = {"user_input_form": [{"field": "text", "label": "Input"}]} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - response = dify_client.get_application_parameters("test_user") - self.assertIn("user_input_form", response.text) - - @patch("dify_client.client.httpx.Client") - @patch("builtins.open", new_callable=mock_open, read_data=b"fake image data") - def test_file_upload(self, mock_file_open, mock_httpx_client): - # Mock the HTTP response - mock_response = Mock() - mock_response.text = '{"name": "panda.jpeg", "id": "test-file-id"}' - mock_response.json.return_value = {"name": "panda.jpeg", "id": "test-file-id"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - # Create client with mocked httpx - dify_client = DifyClient(self.api_key) - file_path = "/path/to/test/panda.jpeg" - file_name = "panda.jpeg" - mime_type = "image/jpeg" - - with open(file_path, "rb") as file: - files = {"file": (file_name, file, mime_type)} - response = dify_client.file_upload("test_user", files) - self.assertIn("name", response.text) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_exceptions.py b/sdks/python-client/tests/test_exceptions.py deleted file mode 100644 index eb44895749..0000000000 --- a/sdks/python-client/tests/test_exceptions.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Tests for custom exceptions.""" - -import unittest -from dify_client.exceptions import ( - DifyClientError, - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, - FileUploadError, - DatasetError, - WorkflowError, -) - - -class TestExceptions(unittest.TestCase): - """Test custom exception classes.""" - - def test_base_exception(self): - """Test base DifyClientError.""" - error = DifyClientError("Test message", 500, {"error": "details"}) - self.assertEqual(str(error), "Test message") - self.assertEqual(error.status_code, 500) - self.assertEqual(error.response, {"error": "details"}) - - def test_api_error(self): - """Test APIError.""" - error = APIError("API failed", 400) - self.assertEqual(error.status_code, 400) - self.assertEqual(error.message, "API failed") - - def test_authentication_error(self): - """Test AuthenticationError.""" - error = AuthenticationError("Invalid API key") - self.assertEqual(str(error), "Invalid API key") - - def test_rate_limit_error(self): - """Test RateLimitError.""" - error = RateLimitError("Rate limited", retry_after=60) - self.assertEqual(error.retry_after, 60) - - error_default = RateLimitError() - self.assertEqual(error_default.retry_after, None) - - def test_validation_error(self): - """Test ValidationError.""" - error = ValidationError("Invalid parameter") - self.assertEqual(str(error), "Invalid parameter") - - def test_network_error(self): - """Test NetworkError.""" - error = NetworkError("Connection failed") - self.assertEqual(str(error), "Connection failed") - - def test_timeout_error(self): - """Test TimeoutError.""" - error = TimeoutError("Request timed out") - self.assertEqual(str(error), "Request timed out") - - def test_file_upload_error(self): - """Test FileUploadError.""" - error = FileUploadError("Upload failed") - self.assertEqual(str(error), "Upload failed") - - def test_dataset_error(self): - """Test DatasetError.""" - error = DatasetError("Dataset operation failed") - self.assertEqual(str(error), "Dataset operation failed") - - def test_workflow_error(self): - """Test WorkflowError.""" - error = WorkflowError("Workflow failed") - self.assertEqual(str(error), "Workflow failed") - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_httpx_migration.py b/sdks/python-client/tests/test_httpx_migration.py deleted file mode 100644 index cf26de6eba..0000000000 --- a/sdks/python-client/tests/test_httpx_migration.py +++ /dev/null @@ -1,333 +0,0 @@ -#!/usr/bin/env python3 -""" -Test suite for httpx migration in the Python SDK. - -This test validates that the migration from requests to httpx maintains -backward compatibility and proper resource management. -""" - -import unittest -from unittest.mock import Mock, patch - -from dify_client import ( - DifyClient, - ChatClient, - CompletionClient, - WorkflowClient, - WorkspaceClient, - KnowledgeBaseClient, -) - - -class TestHttpxMigrationMocked(unittest.TestCase): - """Test cases for httpx migration with mocked requests.""" - - def setUp(self): - """Set up test fixtures.""" - self.api_key = "test-api-key" - self.base_url = "https://api.dify.ai/v1" - - @patch("dify_client.client.httpx.Client") - def test_client_initialization(self, mock_httpx_client): - """Test that client initializes with httpx.Client.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - - # Verify httpx.Client was called with correct parameters - mock_httpx_client.assert_called_once() - call_kwargs = mock_httpx_client.call_args[1] - self.assertEqual(call_kwargs["base_url"], self.base_url) - - # Verify client properties - self.assertEqual(client.api_key, self.api_key) - self.assertEqual(client.base_url, self.base_url) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_context_manager_support(self, mock_httpx_client): - """Test that client works as context manager.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - with DifyClient(self.api_key, self.base_url) as client: - self.assertEqual(client.api_key, self.api_key) - - # Verify close was called - mock_client_instance.close.assert_called_once() - - @patch("dify_client.client.httpx.Client") - def test_manual_close(self, mock_httpx_client): - """Test manual close() method.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - client.close() - - # Verify close was called - mock_client_instance.close.assert_called_once() - - @patch("dify_client.client.httpx.Client") - def test_send_request_httpx_compatibility(self, mock_httpx_client): - """Test _send_request uses httpx.Client.request properly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - response = client._send_request("GET", "/test-endpoint") - - # Verify httpx.Client.request was called correctly - mock_client_instance.request.assert_called_once() - call_args = mock_client_instance.request.call_args - - # Verify method and endpoint - self.assertEqual(call_args[0][0], "GET") - self.assertEqual(call_args[0][1], "/test-endpoint") - - # Verify headers contain authorization - headers = call_args[1]["headers"] - self.assertEqual(headers["Authorization"], f"Bearer {self.api_key}") - self.assertEqual(headers["Content-Type"], "application/json") - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_response_compatibility(self, mock_httpx_client): - """Test httpx.Response is compatible with requests.Response API.""" - mock_response = Mock() - mock_response.json.return_value = {"key": "value"} - mock_response.text = '{"key": "value"}' - mock_response.content = b'{"key": "value"}' - mock_response.status_code = 200 - mock_response.headers = {"Content-Type": "application/json"} - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - response = client._send_request("GET", "/test") - - # Verify all common response methods work - self.assertEqual(response.json(), {"key": "value"}) - self.assertEqual(response.text, '{"key": "value"}') - self.assertEqual(response.content, b'{"key": "value"}') - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers["Content-Type"], "application/json") - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_all_client_classes_use_httpx(self, mock_httpx_client): - """Test that all client classes properly use httpx.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - clients = [ - DifyClient(self.api_key, self.base_url), - ChatClient(self.api_key, self.base_url), - CompletionClient(self.api_key, self.base_url), - WorkflowClient(self.api_key, self.base_url), - WorkspaceClient(self.api_key, self.base_url), - KnowledgeBaseClient(self.api_key, self.base_url), - ] - - # Verify httpx.Client was called for each client - self.assertEqual(mock_httpx_client.call_count, 6) - - # Clean up - for client in clients: - client.close() - - @patch("dify_client.client.httpx.Client") - def test_json_parameter_handling(self, mock_httpx_client): - """Test that json parameter is passed correctly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 # Add status_code attribute - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - test_data = {"key": "value", "number": 123} - - client._send_request("POST", "/test", json=test_data) - - # Verify json parameter was passed - call_args = mock_client_instance.request.call_args - self.assertEqual(call_args[1]["json"], test_data) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_params_parameter_handling(self, mock_httpx_client): - """Test that params parameter is passed correctly.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 # Add status_code attribute - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - client = DifyClient(self.api_key, self.base_url) - test_params = {"page": 1, "limit": 20} - - client._send_request("GET", "/test", params=test_params) - - # Verify params parameter was passed - call_args = mock_client_instance.request.call_args - self.assertEqual(call_args[1]["params"], test_params) - - client.close() - - @patch("dify_client.client.httpx.Client") - def test_inheritance_chain(self, mock_httpx_client): - """Test that inheritance chain is maintained.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - # ChatClient inherits from DifyClient - chat_client = ChatClient(self.api_key, self.base_url) - self.assertIsInstance(chat_client, DifyClient) - - # CompletionClient inherits from DifyClient - completion_client = CompletionClient(self.api_key, self.base_url) - self.assertIsInstance(completion_client, DifyClient) - - # WorkflowClient inherits from DifyClient - workflow_client = WorkflowClient(self.api_key, self.base_url) - self.assertIsInstance(workflow_client, DifyClient) - - # Clean up - chat_client.close() - completion_client.close() - workflow_client.close() - - @patch("dify_client.client.httpx.Client") - def test_nested_context_managers(self, mock_httpx_client): - """Test nested context managers work correctly.""" - mock_client_instance = Mock() - mock_httpx_client.return_value = mock_client_instance - - with DifyClient(self.api_key, self.base_url) as client1: - with ChatClient(self.api_key, self.base_url) as client2: - self.assertEqual(client1.api_key, self.api_key) - self.assertEqual(client2.api_key, self.api_key) - - # Both close methods should have been called - self.assertEqual(mock_client_instance.close.call_count, 2) - - -class TestChatClientHttpx(unittest.TestCase): - """Test ChatClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_create_chat_message_httpx(self, mock_httpx_client): - """Test create_chat_message works with httpx.""" - mock_response = Mock() - mock_response.text = '{"answer": "Hello!"}' - mock_response.json.return_value = {"answer": "Hello!"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with ChatClient("test-key") as client: - response = client.create_chat_message({}, "Hi", "user123") - self.assertIn("answer", response.text) - self.assertEqual(response.json()["answer"], "Hello!") - - -class TestCompletionClientHttpx(unittest.TestCase): - """Test CompletionClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_create_completion_message_httpx(self, mock_httpx_client): - """Test create_completion_message works with httpx.""" - mock_response = Mock() - mock_response.text = '{"answer": "Response"}' - mock_response.json.return_value = {"answer": "Response"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with CompletionClient("test-key") as client: - response = client.create_completion_message({"query": "test"}, "blocking", "user123") - self.assertIn("answer", response.text) - - -class TestKnowledgeBaseClientHttpx(unittest.TestCase): - """Test KnowledgeBaseClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_list_datasets_httpx(self, mock_httpx_client): - """Test list_datasets works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"data": [], "total": 0} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with KnowledgeBaseClient("test-key") as client: - response = client.list_datasets() - data = response.json() - self.assertIn("data", data) - self.assertIn("total", data) - - -class TestWorkflowClientHttpx(unittest.TestCase): - """Test WorkflowClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_run_workflow_httpx(self, mock_httpx_client): - """Test run workflow works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"result": "success"} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with WorkflowClient("test-key") as client: - response = client.run({"input": "test"}, "blocking", "user123") - self.assertEqual(response.json()["result"], "success") - - -class TestWorkspaceClientHttpx(unittest.TestCase): - """Test WorkspaceClient specific httpx integration.""" - - @patch("dify_client.client.httpx.Client") - def test_get_available_models_httpx(self, mock_httpx_client): - """Test get_available_models works with httpx.""" - mock_response = Mock() - mock_response.json.return_value = {"data": []} - mock_response.status_code = 200 - - mock_client_instance = Mock() - mock_client_instance.request.return_value = mock_response - mock_httpx_client.return_value = mock_client_instance - - with WorkspaceClient("test-key") as client: - response = client.get_available_models("llm") - self.assertIn("data", response.json()) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_integration.py b/sdks/python-client/tests/test_integration.py deleted file mode 100644 index 6f38c5de56..0000000000 --- a/sdks/python-client/tests/test_integration.py +++ /dev/null @@ -1,539 +0,0 @@ -"""Integration tests with proper mocking.""" - -import unittest -from unittest.mock import Mock, patch, MagicMock -import json -import httpx -from dify_client import ( - DifyClient, - ChatClient, - CompletionClient, - WorkflowClient, - KnowledgeBaseClient, - WorkspaceClient, -) -from dify_client.exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, -) - - -class TestDifyClientIntegration(unittest.TestCase): - """Integration tests for DifyClient with mocked HTTP responses.""" - - def setUp(self): - self.api_key = "test_api_key" - self.base_url = "https://api.dify.ai/v1" - self.client = DifyClient(api_key=self.api_key, base_url=self.base_url, enable_logging=False) - - @patch("httpx.Client.request") - def test_get_app_info_integration(self, mock_request): - """Test get_app_info integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "app_123", - "name": "Test App", - "description": "A test application", - "mode": "chat", - } - mock_request.return_value = mock_response - - response = self.client.get_app_info() - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["id"], "app_123") - self.assertEqual(data["name"], "Test App") - mock_request.assert_called_once_with( - "GET", - "/info", - json=None, - params=None, - headers={ - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - }, - ) - - @patch("httpx.Client.request") - def test_get_application_parameters_integration(self, mock_request): - """Test get_application_parameters integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "opening_statement": "Hello! How can I help you?", - "suggested_questions": ["What is AI?", "How does this work?"], - "speech_to_text": {"enabled": True}, - "text_to_speech": {"enabled": False}, - } - mock_request.return_value = mock_response - - response = self.client.get_application_parameters("user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["opening_statement"], "Hello! How can I help you?") - self.assertEqual(len(data["suggested_questions"]), 2) - mock_request.assert_called_once_with( - "GET", - "/parameters", - json=None, - params={"user": "user_123"}, - headers={ - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - }, - ) - - @patch("httpx.Client.request") - def test_file_upload_integration(self, mock_request): - """Test file_upload integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "file_123", - "name": "test.txt", - "size": 1024, - "mime_type": "text/plain", - } - mock_request.return_value = mock_response - - files = {"file": ("test.txt", "test content", "text/plain")} - response = self.client.file_upload("user_123", files) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["id"], "file_123") - self.assertEqual(data["name"], "test.txt") - - @patch("httpx.Client.request") - def test_message_feedback_integration(self, mock_request): - """Test message_feedback integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {"success": True} - mock_request.return_value = mock_response - - response = self.client.message_feedback("msg_123", "like", "user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertTrue(data["success"]) - mock_request.assert_called_once_with( - "POST", - "/messages/msg_123/feedbacks", - json={"rating": "like", "user": "user_123"}, - params=None, - headers={ - "Authorization": "Bearer test_api_key", - "Content-Type": "application/json", - }, - ) - - -class TestChatClientIntegration(unittest.TestCase): - """Integration tests for ChatClient.""" - - def setUp(self): - self.client = ChatClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_create_chat_message_blocking(self, mock_request): - """Test create_chat_message with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "msg_123", - "answer": "Hello! How can I help you today?", - "conversation_id": "conv_123", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_chat_message( - inputs={"query": "Hello"}, - query="Hello, AI!", - user="user_123", - response_mode="blocking", - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["answer"], "Hello! How can I help you today?") - self.assertEqual(data["conversation_id"], "conv_123") - - @patch("httpx.Client.request") - def test_create_chat_message_streaming(self, mock_request): - """Test create_chat_message with streaming response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.iter_lines.return_value = [ - b'data: {"answer": "Hello"}', - b'data: {"answer": " world"}', - b'data: {"answer": "!"}', - ] - mock_request.return_value = mock_response - - response = self.client.create_chat_message(inputs={}, query="Hello", user="user_123", response_mode="streaming") - - self.assertEqual(response.status_code, 200) - lines = list(response.iter_lines()) - self.assertEqual(len(lines), 3) - - @patch("httpx.Client.request") - def test_get_conversations_integration(self, mock_request): - """Test get_conversations integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "conv_1", "name": "Conversation 1"}, - {"id": "conv_2", "name": "Conversation 2"}, - ], - "has_more": False, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.get_conversations("user_123", limit=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - self.assertEqual(data["data"][0]["name"], "Conversation 1") - - @patch("httpx.Client.request") - def test_get_conversation_messages_integration(self, mock_request): - """Test get_conversation_messages integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "msg_1", "role": "user", "content": "Hello"}, - {"id": "msg_2", "role": "assistant", "content": "Hi there!"}, - ] - } - mock_request.return_value = mock_response - - response = self.client.get_conversation_messages("user_123", conversation_id="conv_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - self.assertEqual(data["data"][0]["role"], "user") - - -class TestCompletionClientIntegration(unittest.TestCase): - """Integration tests for CompletionClient.""" - - def setUp(self): - self.client = CompletionClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_create_completion_message_blocking(self, mock_request): - """Test create_completion_message with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "comp_123", - "answer": "This is a completion response.", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_completion_message( - inputs={"prompt": "Complete this sentence"}, - response_mode="blocking", - user="user_123", - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["answer"], "This is a completion response.") - - @patch("httpx.Client.request") - def test_create_completion_message_with_files(self, mock_request): - """Test create_completion_message with files.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "comp_124", - "answer": "I can see the image shows...", - "files": [{"id": "file_1", "type": "image"}], - } - mock_request.return_value = mock_response - - files = { - "file": { - "type": "image", - "transfer_method": "remote_url", - "url": "https://example.com/image.jpg", - } - } - response = self.client.create_completion_message( - inputs={"prompt": "Describe this image"}, - response_mode="blocking", - user="user_123", - files=files, - ) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertIn("image", data["answer"]) - self.assertEqual(len(data["files"]), 1) - - -class TestWorkflowClientIntegration(unittest.TestCase): - """Integration tests for WorkflowClient.""" - - def setUp(self): - self.client = WorkflowClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_run_workflow_blocking(self, mock_request): - """Test run workflow with blocking response.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "run_123", - "workflow_id": "workflow_123", - "status": "succeeded", - "inputs": {"query": "Test input"}, - "outputs": {"result": "Test output"}, - "elapsed_time": 2.5, - } - mock_request.return_value = mock_response - - response = self.client.run(inputs={"query": "Test input"}, response_mode="blocking", user="user_123") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["status"], "succeeded") - self.assertEqual(data["outputs"]["result"], "Test output") - - @patch("httpx.Client.request") - def test_get_workflow_logs(self, mock_request): - """Test get_workflow_logs integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "logs": [ - {"id": "log_1", "status": "succeeded", "created_at": 1234567890}, - {"id": "log_2", "status": "failed", "created_at": 1234567891}, - ], - "total": 2, - "page": 1, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.get_workflow_logs(page=1, limit=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["logs"]), 2) - self.assertEqual(data["logs"][0]["status"], "succeeded") - - -class TestKnowledgeBaseClientIntegration(unittest.TestCase): - """Integration tests for KnowledgeBaseClient.""" - - def setUp(self): - self.client = KnowledgeBaseClient("test_api_key") - - @patch("httpx.Client.request") - def test_create_dataset(self, mock_request): - """Test create_dataset integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "id": "dataset_123", - "name": "Test Dataset", - "description": "A test dataset", - "created_at": 1234567890, - } - mock_request.return_value = mock_response - - response = self.client.create_dataset(name="Test Dataset") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["name"], "Test Dataset") - self.assertEqual(data["id"], "dataset_123") - - @patch("httpx.Client.request") - def test_list_datasets(self, mock_request): - """Test list_datasets integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - {"id": "dataset_1", "name": "Dataset 1"}, - {"id": "dataset_2", "name": "Dataset 2"}, - ], - "has_more": False, - "limit": 20, - } - mock_request.return_value = mock_response - - response = self.client.list_datasets(page=1, page_size=20) - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["data"]), 2) - - @patch("httpx.Client.request") - def test_create_document_by_text(self, mock_request): - """Test create_document_by_text integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "document": { - "id": "doc_123", - "name": "Test Document", - "word_count": 100, - "status": "indexing", - } - } - mock_request.return_value = mock_response - - # Mock dataset_id - self.client.dataset_id = "dataset_123" - - response = self.client.create_document_by_text(name="Test Document", text="This is test document content.") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(data["document"]["name"], "Test Document") - self.assertEqual(data["document"]["word_count"], 100) - - -class TestWorkspaceClientIntegration(unittest.TestCase): - """Integration tests for WorkspaceClient.""" - - def setUp(self): - self.client = WorkspaceClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_get_available_models(self, mock_request): - """Test get_available_models integration.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "models": [ - {"id": "gpt-4", "name": "GPT-4", "provider": "openai"}, - {"id": "claude-3", "name": "Claude 3", "provider": "anthropic"}, - ] - } - mock_request.return_value = mock_response - - response = self.client.get_available_models("llm") - data = response.json() - - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data["models"]), 2) - self.assertEqual(data["models"][0]["id"], "gpt-4") - - -class TestErrorScenariosIntegration(unittest.TestCase): - """Integration tests for error scenarios.""" - - def setUp(self): - self.client = DifyClient("test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_authentication_error_integration(self, mock_request): - """Test authentication error in integration.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Invalid API key"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Invalid API key") - self.assertEqual(context.exception.status_code, 401) - - @patch("httpx.Client.request") - def test_rate_limit_error_integration(self, mock_request): - """Test rate limit error in integration.""" - mock_response = Mock() - mock_response.status_code = 429 - mock_response.json.return_value = {"message": "Rate limit exceeded"} - mock_response.headers = {"Retry-After": "60"} - mock_request.return_value = mock_response - - with self.assertRaises(RateLimitError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Rate limit exceeded") - self.assertEqual(context.exception.retry_after, "60") - - @patch("httpx.Client.request") - def test_server_error_with_retry_integration(self, mock_request): - """Test server error with retry in integration.""" - # API errors don't retry by design - only network/timeout errors retry - mock_response_500 = Mock() - mock_response_500.status_code = 500 - mock_response_500.json.return_value = {"message": "Internal server error"} - - mock_request.return_value = mock_response_500 - - with patch("time.sleep"): # Skip actual sleep - with self.assertRaises(APIError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_validation_error_integration(self, mock_request): - """Test validation error in integration.""" - mock_response = Mock() - mock_response.status_code = 422 - mock_response.json.return_value = { - "message": "Validation failed", - "details": {"field": "query", "error": "required"}, - } - mock_request.return_value = mock_response - - with self.assertRaises(ValidationError) as context: - self.client.get_app_info() - - self.assertEqual(str(context.exception), "Validation failed") - self.assertEqual(context.exception.status_code, 422) - - -class TestContextManagerIntegration(unittest.TestCase): - """Integration tests for context manager usage.""" - - @patch("httpx.Client.close") - @patch("httpx.Client.request") - def test_context_manager_usage(self, mock_request, mock_close): - """Test context manager properly closes connections.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {"id": "app_123", "name": "Test App"} - mock_request.return_value = mock_response - - with DifyClient("test_api_key") as client: - response = client.get_app_info() - self.assertEqual(response.status_code, 200) - - # Verify close was called - mock_close.assert_called_once() - - @patch("httpx.Client.close") - def test_manual_close(self, mock_close): - """Test manual close method.""" - client = DifyClient("test_api_key") - client.close() - mock_close.assert_called_once() - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_models.py b/sdks/python-client/tests/test_models.py deleted file mode 100644 index db9d92ad5b..0000000000 --- a/sdks/python-client/tests/test_models.py +++ /dev/null @@ -1,640 +0,0 @@ -"""Unit tests for response models.""" - -import unittest -import json -from datetime import datetime -from dify_client.models import ( - BaseResponse, - ErrorResponse, - FileInfo, - MessageResponse, - ConversationResponse, - DatasetResponse, - DocumentResponse, - DocumentSegmentResponse, - WorkflowRunResponse, - ApplicationParametersResponse, - AnnotationResponse, - PaginatedResponse, - ConversationVariableResponse, - FileUploadResponse, - AudioResponse, - SuggestedQuestionsResponse, - AppInfoResponse, - WorkspaceModelsResponse, - HitTestingResponse, - DatasetTagsResponse, - WorkflowLogsResponse, - ModelProviderResponse, - FileInfoResponse, - WorkflowDraftResponse, - ApiTokenResponse, - JobStatusResponse, - DatasetQueryResponse, - DatasetTemplateResponse, -) - - -class TestResponseModels(unittest.TestCase): - """Test cases for response model classes.""" - - def test_base_response(self): - """Test BaseResponse model.""" - response = BaseResponse(success=True, message="Operation successful") - self.assertTrue(response.success) - self.assertEqual(response.message, "Operation successful") - - def test_base_response_defaults(self): - """Test BaseResponse with default values.""" - response = BaseResponse(success=True) - self.assertTrue(response.success) - self.assertIsNone(response.message) - - def test_error_response(self): - """Test ErrorResponse model.""" - response = ErrorResponse( - success=False, - message="Error occurred", - error_code="VALIDATION_ERROR", - details={"field": "invalid_value"}, - ) - self.assertFalse(response.success) - self.assertEqual(response.message, "Error occurred") - self.assertEqual(response.error_code, "VALIDATION_ERROR") - self.assertEqual(response.details["field"], "invalid_value") - - def test_file_info(self): - """Test FileInfo model.""" - now = datetime.now() - file_info = FileInfo( - id="file_123", - name="test.txt", - size=1024, - mime_type="text/plain", - url="https://example.com/file.txt", - created_at=now, - ) - self.assertEqual(file_info.id, "file_123") - self.assertEqual(file_info.name, "test.txt") - self.assertEqual(file_info.size, 1024) - self.assertEqual(file_info.mime_type, "text/plain") - self.assertEqual(file_info.url, "https://example.com/file.txt") - self.assertEqual(file_info.created_at, now) - - def test_message_response(self): - """Test MessageResponse model.""" - response = MessageResponse( - success=True, - id="msg_123", - answer="Hello, world!", - conversation_id="conv_123", - created_at=1234567890, - metadata={"model": "gpt-4"}, - files=[{"id": "file_1", "type": "image"}], - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "msg_123") - self.assertEqual(response.answer, "Hello, world!") - self.assertEqual(response.conversation_id, "conv_123") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.metadata["model"], "gpt-4") - self.assertEqual(response.files[0]["id"], "file_1") - - def test_conversation_response(self): - """Test ConversationResponse model.""" - response = ConversationResponse( - success=True, - id="conv_123", - name="Test Conversation", - inputs={"query": "Hello"}, - status="active", - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "conv_123") - self.assertEqual(response.name, "Test Conversation") - self.assertEqual(response.inputs["query"], "Hello") - self.assertEqual(response.status, "active") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_dataset_response(self): - """Test DatasetResponse model.""" - response = DatasetResponse( - success=True, - id="dataset_123", - name="Test Dataset", - description="A test dataset", - permission="read", - indexing_technique="high_quality", - embedding_model="text-embedding-ada-002", - embedding_model_provider="openai", - retrieval_model={"search_type": "semantic"}, - document_count=10, - word_count=5000, - app_count=2, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "dataset_123") - self.assertEqual(response.name, "Test Dataset") - self.assertEqual(response.description, "A test dataset") - self.assertEqual(response.permission, "read") - self.assertEqual(response.indexing_technique, "high_quality") - self.assertEqual(response.embedding_model, "text-embedding-ada-002") - self.assertEqual(response.embedding_model_provider, "openai") - self.assertEqual(response.retrieval_model["search_type"], "semantic") - self.assertEqual(response.document_count, 10) - self.assertEqual(response.word_count, 5000) - self.assertEqual(response.app_count, 2) - - def test_document_response(self): - """Test DocumentResponse model.""" - response = DocumentResponse( - success=True, - id="doc_123", - name="test_document.txt", - data_source_type="upload_file", - position=1, - enabled=True, - word_count=1000, - hit_count=5, - doc_form="text_model", - created_at=1234567890.0, - indexing_status="completed", - completed_at=1234567891.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "doc_123") - self.assertEqual(response.name, "test_document.txt") - self.assertEqual(response.data_source_type, "upload_file") - self.assertEqual(response.position, 1) - self.assertTrue(response.enabled) - self.assertEqual(response.word_count, 1000) - self.assertEqual(response.hit_count, 5) - self.assertEqual(response.doc_form, "text_model") - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.indexing_status, "completed") - self.assertEqual(response.completed_at, 1234567891.0) - - def test_document_segment_response(self): - """Test DocumentSegmentResponse model.""" - response = DocumentSegmentResponse( - success=True, - id="seg_123", - position=1, - document_id="doc_123", - content="This is a test segment.", - answer="Test answer", - word_count=5, - tokens=10, - keywords=["test", "segment"], - hit_count=2, - enabled=True, - status="completed", - created_at=1234567890.0, - completed_at=1234567891.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "seg_123") - self.assertEqual(response.position, 1) - self.assertEqual(response.document_id, "doc_123") - self.assertEqual(response.content, "This is a test segment.") - self.assertEqual(response.answer, "Test answer") - self.assertEqual(response.word_count, 5) - self.assertEqual(response.tokens, 10) - self.assertEqual(response.keywords, ["test", "segment"]) - self.assertEqual(response.hit_count, 2) - self.assertTrue(response.enabled) - self.assertEqual(response.status, "completed") - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.completed_at, 1234567891.0) - - def test_workflow_run_response(self): - """Test WorkflowRunResponse model.""" - response = WorkflowRunResponse( - success=True, - id="run_123", - workflow_id="workflow_123", - status="succeeded", - inputs={"query": "test"}, - outputs={"answer": "result"}, - elapsed_time=5.5, - total_tokens=100, - total_steps=3, - created_at=1234567890.0, - finished_at=1234567895.5, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "run_123") - self.assertEqual(response.workflow_id, "workflow_123") - self.assertEqual(response.status, "succeeded") - self.assertEqual(response.inputs["query"], "test") - self.assertEqual(response.outputs["answer"], "result") - self.assertEqual(response.elapsed_time, 5.5) - self.assertEqual(response.total_tokens, 100) - self.assertEqual(response.total_steps, 3) - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.finished_at, 1234567895.5) - - def test_application_parameters_response(self): - """Test ApplicationParametersResponse model.""" - response = ApplicationParametersResponse( - success=True, - opening_statement="Hello! How can I help you?", - suggested_questions=["What is AI?", "How does this work?"], - speech_to_text={"enabled": True}, - text_to_speech={"enabled": False, "voice": "alloy"}, - retriever_resource={"enabled": True}, - sensitive_word_avoidance={"enabled": False}, - file_upload={"enabled": True, "file_size_limit": 10485760}, - system_parameters={"max_tokens": 1000}, - user_input_form=[{"type": "text", "label": "Query"}], - ) - self.assertTrue(response.success) - self.assertEqual(response.opening_statement, "Hello! How can I help you?") - self.assertEqual(response.suggested_questions, ["What is AI?", "How does this work?"]) - self.assertTrue(response.speech_to_text["enabled"]) - self.assertFalse(response.text_to_speech["enabled"]) - self.assertEqual(response.text_to_speech["voice"], "alloy") - self.assertTrue(response.retriever_resource["enabled"]) - self.assertFalse(response.sensitive_word_avoidance["enabled"]) - self.assertTrue(response.file_upload["enabled"]) - self.assertEqual(response.file_upload["file_size_limit"], 10485760) - self.assertEqual(response.system_parameters["max_tokens"], 1000) - self.assertEqual(response.user_input_form[0]["type"], "text") - - def test_annotation_response(self): - """Test AnnotationResponse model.""" - response = AnnotationResponse( - success=True, - id="annotation_123", - question="What is the capital of France?", - answer="Paris", - content="Additional context", - created_at=1234567890.0, - updated_at=1234567891.0, - created_by="user_123", - updated_by="user_123", - hit_count=5, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "annotation_123") - self.assertEqual(response.question, "What is the capital of France?") - self.assertEqual(response.answer, "Paris") - self.assertEqual(response.content, "Additional context") - self.assertEqual(response.created_at, 1234567890.0) - self.assertEqual(response.updated_at, 1234567891.0) - self.assertEqual(response.created_by, "user_123") - self.assertEqual(response.updated_by, "user_123") - self.assertEqual(response.hit_count, 5) - - def test_paginated_response(self): - """Test PaginatedResponse model.""" - response = PaginatedResponse( - success=True, - data=[{"id": 1}, {"id": 2}, {"id": 3}], - has_more=True, - limit=10, - total=100, - page=1, - ) - self.assertTrue(response.success) - self.assertEqual(len(response.data), 3) - self.assertEqual(response.data[0]["id"], 1) - self.assertTrue(response.has_more) - self.assertEqual(response.limit, 10) - self.assertEqual(response.total, 100) - self.assertEqual(response.page, 1) - - def test_conversation_variable_response(self): - """Test ConversationVariableResponse model.""" - response = ConversationVariableResponse( - success=True, - conversation_id="conv_123", - variables=[ - {"id": "var_1", "name": "user_name", "value": "John"}, - {"id": "var_2", "name": "preferences", "value": {"theme": "dark"}}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.conversation_id, "conv_123") - self.assertEqual(len(response.variables), 2) - self.assertEqual(response.variables[0]["name"], "user_name") - self.assertEqual(response.variables[0]["value"], "John") - self.assertEqual(response.variables[1]["name"], "preferences") - self.assertEqual(response.variables[1]["value"]["theme"], "dark") - - def test_file_upload_response(self): - """Test FileUploadResponse model.""" - response = FileUploadResponse( - success=True, - id="file_123", - name="test.txt", - size=1024, - mime_type="text/plain", - url="https://example.com/files/test.txt", - created_at=1234567890.0, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "file_123") - self.assertEqual(response.name, "test.txt") - self.assertEqual(response.size, 1024) - self.assertEqual(response.mime_type, "text/plain") - self.assertEqual(response.url, "https://example.com/files/test.txt") - self.assertEqual(response.created_at, 1234567890.0) - - def test_audio_response(self): - """Test AudioResponse model.""" - response = AudioResponse( - success=True, - audio="base64_encoded_audio_data", - audio_url="https://example.com/audio.mp3", - duration=10.5, - sample_rate=44100, - ) - self.assertTrue(response.success) - self.assertEqual(response.audio, "base64_encoded_audio_data") - self.assertEqual(response.audio_url, "https://example.com/audio.mp3") - self.assertEqual(response.duration, 10.5) - self.assertEqual(response.sample_rate, 44100) - - def test_suggested_questions_response(self): - """Test SuggestedQuestionsResponse model.""" - response = SuggestedQuestionsResponse( - success=True, - message_id="msg_123", - questions=[ - "What is machine learning?", - "How does AI work?", - "Can you explain neural networks?", - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.message_id, "msg_123") - self.assertEqual(len(response.questions), 3) - self.assertEqual(response.questions[0], "What is machine learning?") - - def test_app_info_response(self): - """Test AppInfoResponse model.""" - response = AppInfoResponse( - success=True, - id="app_123", - name="Test App", - description="A test application", - icon="🤖", - icon_background="#FF6B6B", - mode="chat", - tags=["AI", "Chat", "Test"], - enable_site=True, - enable_api=True, - api_token="app_token_123", - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "app_123") - self.assertEqual(response.name, "Test App") - self.assertEqual(response.description, "A test application") - self.assertEqual(response.icon, "🤖") - self.assertEqual(response.icon_background, "#FF6B6B") - self.assertEqual(response.mode, "chat") - self.assertEqual(response.tags, ["AI", "Chat", "Test"]) - self.assertTrue(response.enable_site) - self.assertTrue(response.enable_api) - self.assertEqual(response.api_token, "app_token_123") - - def test_workspace_models_response(self): - """Test WorkspaceModelsResponse model.""" - response = WorkspaceModelsResponse( - success=True, - models=[ - {"id": "gpt-4", "name": "GPT-4", "provider": "openai"}, - {"id": "claude-3", "name": "Claude 3", "provider": "anthropic"}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(len(response.models), 2) - self.assertEqual(response.models[0]["id"], "gpt-4") - self.assertEqual(response.models[0]["name"], "GPT-4") - self.assertEqual(response.models[0]["provider"], "openai") - - def test_hit_testing_response(self): - """Test HitTestingResponse model.""" - response = HitTestingResponse( - success=True, - query="What is machine learning?", - records=[ - {"content": "Machine learning is a subset of AI...", "score": 0.95}, - {"content": "ML algorithms learn from data...", "score": 0.87}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(response.query, "What is machine learning?") - self.assertEqual(len(response.records), 2) - self.assertEqual(response.records[0]["score"], 0.95) - - def test_dataset_tags_response(self): - """Test DatasetTagsResponse model.""" - response = DatasetTagsResponse( - success=True, - tags=[ - {"id": "tag_1", "name": "Technology", "color": "#FF0000"}, - {"id": "tag_2", "name": "Science", "color": "#00FF00"}, - ], - ) - self.assertTrue(response.success) - self.assertEqual(len(response.tags), 2) - self.assertEqual(response.tags[0]["name"], "Technology") - self.assertEqual(response.tags[0]["color"], "#FF0000") - - def test_workflow_logs_response(self): - """Test WorkflowLogsResponse model.""" - response = WorkflowLogsResponse( - success=True, - logs=[ - {"id": "log_1", "status": "succeeded", "created_at": 1234567890}, - {"id": "log_2", "status": "failed", "created_at": 1234567891}, - ], - total=50, - page=1, - limit=10, - has_more=True, - ) - self.assertTrue(response.success) - self.assertEqual(len(response.logs), 2) - self.assertEqual(response.logs[0]["status"], "succeeded") - self.assertEqual(response.total, 50) - self.assertEqual(response.page, 1) - self.assertEqual(response.limit, 10) - self.assertTrue(response.has_more) - - def test_model_serialization(self): - """Test that models can be serialized to JSON.""" - response = MessageResponse( - success=True, - id="msg_123", - answer="Hello, world!", - conversation_id="conv_123", - ) - - # Convert to dict and then to JSON - response_dict = { - "success": response.success, - "id": response.id, - "answer": response.answer, - "conversation_id": response.conversation_id, - } - - json_str = json.dumps(response_dict) - parsed = json.loads(json_str) - - self.assertTrue(parsed["success"]) - self.assertEqual(parsed["id"], "msg_123") - self.assertEqual(parsed["answer"], "Hello, world!") - self.assertEqual(parsed["conversation_id"], "conv_123") - - # Tests for new response models - def test_model_provider_response(self): - """Test ModelProviderResponse model.""" - response = ModelProviderResponse( - success=True, - provider_name="openai", - provider_type="llm", - models=[ - {"id": "gpt-4", "name": "GPT-4", "max_tokens": 8192}, - {"id": "gpt-3.5-turbo", "name": "GPT-3.5 Turbo", "max_tokens": 4096}, - ], - is_enabled=True, - credentials={"api_key": "sk-..."}, - ) - self.assertTrue(response.success) - self.assertEqual(response.provider_name, "openai") - self.assertEqual(response.provider_type, "llm") - self.assertEqual(len(response.models), 2) - self.assertEqual(response.models[0]["id"], "gpt-4") - self.assertTrue(response.is_enabled) - self.assertEqual(response.credentials["api_key"], "sk-...") - - def test_file_info_response(self): - """Test FileInfoResponse model.""" - response = FileInfoResponse( - success=True, - id="file_123", - name="document.pdf", - size=2048576, - mime_type="application/pdf", - url="https://example.com/files/document.pdf", - created_at=1234567890, - metadata={"pages": 10, "author": "John Doe"}, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "file_123") - self.assertEqual(response.name, "document.pdf") - self.assertEqual(response.size, 2048576) - self.assertEqual(response.mime_type, "application/pdf") - self.assertEqual(response.url, "https://example.com/files/document.pdf") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.metadata["pages"], 10) - - def test_workflow_draft_response(self): - """Test WorkflowDraftResponse model.""" - response = WorkflowDraftResponse( - success=True, - id="draft_123", - app_id="app_456", - draft_data={"nodes": [], "edges": [], "config": {"name": "Test Workflow"}}, - version=1, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "draft_123") - self.assertEqual(response.app_id, "app_456") - self.assertEqual(response.draft_data["config"]["name"], "Test Workflow") - self.assertEqual(response.version, 1) - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_api_token_response(self): - """Test ApiTokenResponse model.""" - response = ApiTokenResponse( - success=True, - id="token_123", - name="Production Token", - token="app-xxxxxxxxxxxx", - description="Token for production environment", - created_at=1234567890, - last_used_at=1234567891, - is_active=True, - ) - self.assertTrue(response.success) - self.assertEqual(response.id, "token_123") - self.assertEqual(response.name, "Production Token") - self.assertEqual(response.token, "app-xxxxxxxxxxxx") - self.assertEqual(response.description, "Token for production environment") - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.last_used_at, 1234567891) - self.assertTrue(response.is_active) - - def test_job_status_response(self): - """Test JobStatusResponse model.""" - response = JobStatusResponse( - success=True, - job_id="job_123", - job_status="running", - error_msg=None, - progress=0.75, - created_at=1234567890, - updated_at=1234567891, - ) - self.assertTrue(response.success) - self.assertEqual(response.job_id, "job_123") - self.assertEqual(response.job_status, "running") - self.assertIsNone(response.error_msg) - self.assertEqual(response.progress, 0.75) - self.assertEqual(response.created_at, 1234567890) - self.assertEqual(response.updated_at, 1234567891) - - def test_dataset_query_response(self): - """Test DatasetQueryResponse model.""" - response = DatasetQueryResponse( - success=True, - query="What is machine learning?", - records=[ - {"content": "Machine learning is...", "score": 0.95}, - {"content": "ML algorithms...", "score": 0.87}, - ], - total=2, - search_time=0.123, - retrieval_model={"method": "semantic_search", "top_k": 3}, - ) - self.assertTrue(response.success) - self.assertEqual(response.query, "What is machine learning?") - self.assertEqual(len(response.records), 2) - self.assertEqual(response.total, 2) - self.assertEqual(response.search_time, 0.123) - self.assertEqual(response.retrieval_model["method"], "semantic_search") - - def test_dataset_template_response(self): - """Test DatasetTemplateResponse model.""" - response = DatasetTemplateResponse( - success=True, - template_name="customer_support", - display_name="Customer Support", - description="Template for customer support knowledge base", - category="support", - icon="🎧", - config_schema={"fields": [{"name": "category", "type": "string"}]}, - ) - self.assertTrue(response.success) - self.assertEqual(response.template_name, "customer_support") - self.assertEqual(response.display_name, "Customer Support") - self.assertEqual(response.description, "Template for customer support knowledge base") - self.assertEqual(response.category, "support") - self.assertEqual(response.icon, "🎧") - self.assertEqual(response.config_schema["fields"][0]["name"], "category") - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/tests/test_retry_and_error_handling.py b/sdks/python-client/tests/test_retry_and_error_handling.py deleted file mode 100644 index bd415bde43..0000000000 --- a/sdks/python-client/tests/test_retry_and_error_handling.py +++ /dev/null @@ -1,313 +0,0 @@ -"""Unit tests for retry mechanism and error handling.""" - -import unittest -from unittest.mock import Mock, patch, MagicMock -import httpx -from dify_client.client import DifyClient -from dify_client.exceptions import ( - APIError, - AuthenticationError, - RateLimitError, - ValidationError, - NetworkError, - TimeoutError, - FileUploadError, -) - - -class TestRetryMechanism(unittest.TestCase): - """Test cases for retry mechanism.""" - - def setUp(self): - self.api_key = "test_api_key" - self.base_url = "https://api.dify.ai/v1" - self.client = DifyClient( - api_key=self.api_key, - base_url=self.base_url, - max_retries=3, - retry_delay=0.1, # Short delay for tests - enable_logging=False, - ) - - @patch("httpx.Client.request") - def test_successful_request_no_retry(self, mock_request): - """Test that successful requests don't trigger retries.""" - mock_response = Mock() - mock_response.status_code = 200 - mock_response.content = b'{"success": true}' - mock_request.return_value = mock_response - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response, mock_response) - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_retry_on_network_error(self, mock_sleep, mock_request): - """Test retry on network errors.""" - # First two calls raise network error, third succeeds - mock_request.side_effect = [ - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - Mock(status_code=200, content=b'{"success": true}'), - ] - mock_response = Mock() - mock_response.status_code = 200 - mock_response.content = b'{"success": true}' - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response.status_code, 200) - self.assertEqual(mock_request.call_count, 3) - self.assertEqual(mock_sleep.call_count, 2) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_retry_on_timeout_error(self, mock_sleep, mock_request): - """Test retry on timeout errors.""" - mock_request.side_effect = [ - httpx.TimeoutException("Request timed out"), - httpx.TimeoutException("Request timed out"), - Mock(status_code=200, content=b'{"success": true}'), - ] - - response = self.client._send_request("GET", "/test") - - self.assertEqual(response.status_code, 200) - self.assertEqual(mock_request.call_count, 3) - self.assertEqual(mock_sleep.call_count, 2) - - @patch("httpx.Client.request") - @patch("time.sleep") - def test_max_retries_exceeded(self, mock_sleep, mock_request): - """Test behavior when max retries are exceeded.""" - mock_request.side_effect = httpx.NetworkError("Persistent network error") - - with self.assertRaises(NetworkError): - self.client._send_request("GET", "/test") - - self.assertEqual(mock_request.call_count, 4) # 1 initial + 3 retries - self.assertEqual(mock_sleep.call_count, 3) - - @patch("httpx.Client.request") - def test_no_retry_on_client_error(self, mock_request): - """Test that client errors (4xx) don't trigger retries.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Unauthorized"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError): - self.client._send_request("GET", "/test") - - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_retry_on_server_error(self, mock_request): - """Test that server errors (5xx) don't retry - they raise APIError immediately.""" - mock_response_500 = Mock() - mock_response_500.status_code = 500 - mock_response_500.json.return_value = {"message": "Internal server error"} - - mock_request.return_value = mock_response_500 - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(context.exception.status_code, 500) - # Should not retry server errors - self.assertEqual(mock_request.call_count, 1) - - @patch("httpx.Client.request") - def test_exponential_backoff(self, mock_request): - """Test exponential backoff timing.""" - mock_request.side_effect = [ - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), - httpx.NetworkError("Connection failed"), # All attempts fail - ] - - with patch("time.sleep") as mock_sleep: - with self.assertRaises(NetworkError): - self.client._send_request("GET", "/test") - - # Check exponential backoff: 0.1, 0.2, 0.4 - expected_calls = [0.1, 0.2, 0.4] - actual_calls = [call[0][0] for call in mock_sleep.call_args_list] - self.assertEqual(actual_calls, expected_calls) - - -class TestErrorHandling(unittest.TestCase): - """Test cases for error handling.""" - - def setUp(self): - self.client = DifyClient(api_key="test_api_key", enable_logging=False) - - @patch("httpx.Client.request") - def test_authentication_error(self, mock_request): - """Test AuthenticationError handling.""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.json.return_value = {"message": "Invalid API key"} - mock_request.return_value = mock_response - - with self.assertRaises(AuthenticationError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Invalid API key") - self.assertEqual(context.exception.status_code, 401) - - @patch("httpx.Client.request") - def test_rate_limit_error(self, mock_request): - """Test RateLimitError handling.""" - mock_response = Mock() - mock_response.status_code = 429 - mock_response.json.return_value = {"message": "Rate limit exceeded"} - mock_response.headers = {"Retry-After": "60"} - mock_request.return_value = mock_response - - with self.assertRaises(RateLimitError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Rate limit exceeded") - self.assertEqual(context.exception.retry_after, "60") - - @patch("httpx.Client.request") - def test_validation_error(self, mock_request): - """Test ValidationError handling.""" - mock_response = Mock() - mock_response.status_code = 422 - mock_response.json.return_value = {"message": "Invalid parameters"} - mock_request.return_value = mock_response - - with self.assertRaises(ValidationError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Invalid parameters") - self.assertEqual(context.exception.status_code, 422) - - @patch("httpx.Client.request") - def test_api_error(self, mock_request): - """Test general APIError handling.""" - mock_response = Mock() - mock_response.status_code = 500 - mock_response.json.return_value = {"message": "Internal server error"} - mock_request.return_value = mock_response - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "Internal server error") - self.assertEqual(context.exception.status_code, 500) - - @patch("httpx.Client.request") - def test_error_response_without_json(self, mock_request): - """Test error handling when response doesn't contain valid JSON.""" - mock_response = Mock() - mock_response.status_code = 500 - mock_response.content = b"Internal Server Error" - mock_response.json.side_effect = ValueError("No JSON object could be decoded") - mock_request.return_value = mock_response - - with self.assertRaises(APIError) as context: - self.client._send_request("GET", "/test") - - self.assertEqual(str(context.exception), "HTTP 500") - - @patch("httpx.Client.request") - def test_file_upload_error(self, mock_request): - """Test FileUploadError handling.""" - mock_response = Mock() - mock_response.status_code = 400 - mock_response.json.return_value = {"message": "File upload failed"} - mock_request.return_value = mock_response - - with self.assertRaises(FileUploadError) as context: - self.client._send_request_with_files("POST", "/upload", {}, {}) - - self.assertEqual(str(context.exception), "File upload failed") - self.assertEqual(context.exception.status_code, 400) - - -class TestParameterValidation(unittest.TestCase): - """Test cases for parameter validation.""" - - def setUp(self): - self.client = DifyClient(api_key="test_api_key", enable_logging=False) - - def test_empty_string_validation(self): - """Test validation of empty strings.""" - with self.assertRaises(ValidationError): - self.client._validate_params(empty_string="") - - def test_whitespace_only_string_validation(self): - """Test validation of whitespace-only strings.""" - with self.assertRaises(ValidationError): - self.client._validate_params(whitespace_string=" ") - - def test_long_string_validation(self): - """Test validation of overly long strings.""" - long_string = "a" * 10001 # Exceeds 10000 character limit - with self.assertRaises(ValidationError): - self.client._validate_params(long_string=long_string) - - def test_large_list_validation(self): - """Test validation of overly large lists.""" - large_list = list(range(1001)) # Exceeds 1000 item limit - with self.assertRaises(ValidationError): - self.client._validate_params(large_list=large_list) - - def test_large_dict_validation(self): - """Test validation of overly large dictionaries.""" - large_dict = {f"key_{i}": i for i in range(101)} # Exceeds 100 item limit - with self.assertRaises(ValidationError): - self.client._validate_params(large_dict=large_dict) - - def test_valid_parameters_pass(self): - """Test that valid parameters pass validation.""" - # Should not raise any exception - self.client._validate_params( - valid_string="Hello, World!", - valid_list=[1, 2, 3], - valid_dict={"key": "value"}, - none_value=None, - ) - - def test_message_feedback_validation(self): - """Test validation in message_feedback method.""" - with self.assertRaises(ValidationError): - self.client.message_feedback("msg_id", "invalid_rating", "user") - - def test_completion_message_validation(self): - """Test validation in create_completion_message method.""" - from dify_client.client import CompletionClient - - client = CompletionClient("test_api_key") - - with self.assertRaises(ValidationError): - client.create_completion_message( - inputs="not_a_dict", # Should be a dict - response_mode="invalid_mode", # Should be 'blocking' or 'streaming' - user="test_user", - ) - - def test_chat_message_validation(self): - """Test validation in create_chat_message method.""" - from dify_client.client import ChatClient - - client = ChatClient("test_api_key") - - with self.assertRaises(ValidationError): - client.create_chat_message( - inputs="not_a_dict", # Should be a dict - query="", # Should not be empty - user="test_user", - response_mode="invalid_mode", # Should be 'blocking' or 'streaming' - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/sdks/python-client/uv.lock b/sdks/python-client/uv.lock deleted file mode 100644 index 4a9d7d5193..0000000000 --- a/sdks/python-client/uv.lock +++ /dev/null @@ -1,307 +0,0 @@ -version = 1 -revision = 3 -requires-python = ">=3.10" - -[[package]] -name = "aiofiles" -version = "25.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/41/c3/534eac40372d8ee36ef40df62ec129bee4fdb5ad9706e58a29be53b2c970/aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2", size = 46354, upload-time = "2025-10-09T20:51:04.358Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl", hash = "sha256:abe311e527c862958650f9438e859c1fa7568a141b22abcd015e120e86a85695", size = 14668, upload-time = "2025-10-09T20:51:03.174Z" }, -] - -[[package]] -name = "anyio" -version = "4.11.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "idna" }, - { name = "sniffio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c6/78/7d432127c41b50bccba979505f272c16cbcadcc33645d5fa3a738110ae75/anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4", size = 219094, upload-time = "2025-09-23T09:19:12.58Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, -] - -[[package]] -name = "backports-asyncio-runner" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, -] - -[[package]] -name = "certifi" -version = "2025.10.5" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4c/5b/b6ce21586237c77ce67d01dc5507039d444b630dd76611bbca2d8e5dcd91/certifi-2025.10.5.tar.gz", hash = "sha256:47c09d31ccf2acf0be3f701ea53595ee7e0b8fa08801c6624be771df09ae7b43", size = 164519, upload-time = "2025-10-05T04:12:15.808Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e4/37/af0d2ef3967ac0d6113837b44a4f0bfe1328c2b9763bd5b1744520e5cfed/certifi-2025.10.5-py3-none-any.whl", hash = "sha256:0f212c2744a9bb6de0c56639a6f68afe01ecd92d91f14ae897c4fe7bbeeef0de", size = 163286, upload-time = "2025-10-05T04:12:14.03Z" }, -] - -[[package]] -name = "colorama" -version = "0.4.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, -] - -[[package]] -name = "dify-client" -version = "0.1.12" -source = { editable = "." } -dependencies = [ - { name = "aiofiles" }, - { name = "httpx", extra = ["http2"] }, -] - -[package.optional-dependencies] -dev = [ - { name = "pytest" }, - { name = "pytest-asyncio" }, -] - -[package.metadata] -requires-dist = [ - { name = "aiofiles", specifier = ">=23.0.0" }, - { name = "httpx", extras = ["http2"], specifier = ">=0.27.0" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, - { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, -] -provides-extras = ["dev"] - -[[package]] -name = "exceptiongroup" -version = "1.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, -] - -[[package]] -name = "h11" -version = "0.16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, -] - -[[package]] -name = "h2" -version = "4.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "hpack" }, - { name = "hyperframe" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1d/17/afa56379f94ad0fe8defd37d6eb3f89a25404ffc71d4d848893d270325fc/h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1", size = 2152026, upload-time = "2025-08-23T18:12:19.778Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" }, -] - -[[package]] -name = "hpack" -version = "4.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" }, -] - -[[package]] -name = "httpcore" -version = "1.0.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, - { name = "h11" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, -] - -[[package]] -name = "httpx" -version = "0.28.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "certifi" }, - { name = "httpcore" }, - { name = "idna" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, -] - -[package.optional-dependencies] -http2 = [ - { name = "h2" }, -] - -[[package]] -name = "hyperframe" -version = "6.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" }, -] - -[[package]] -name = "idna" -version = "3.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, -] - -[[package]] -name = "iniconfig" -version = "2.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, -] - -[[package]] -name = "packaging" -version = "25.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, -] - -[[package]] -name = "pluggy" -version = "1.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, -] - -[[package]] -name = "pygments" -version = "2.19.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, -] - -[[package]] -name = "pytest" -version = "8.4.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, - { name = "iniconfig" }, - { name = "packaging" }, - { name = "pluggy" }, - { name = "pygments" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, -] - -[[package]] -name = "pytest-asyncio" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, - { name = "pytest" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, -] - -[[package]] -name = "sniffio" -version = "1.3.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, -] - -[[package]] -name = "tomli" -version = "2.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" }, - { url = "https://files.pythonhosted.org/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" }, - { url = "https://files.pythonhosted.org/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" }, - { url = "https://files.pythonhosted.org/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" }, - { url = "https://files.pythonhosted.org/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" }, - { url = "https://files.pythonhosted.org/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" }, - { url = "https://files.pythonhosted.org/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" }, - { url = "https://files.pythonhosted.org/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" }, - { url = "https://files.pythonhosted.org/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" }, - { url = "https://files.pythonhosted.org/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" }, - { url = "https://files.pythonhosted.org/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" }, - { url = "https://files.pythonhosted.org/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" }, - { url = "https://files.pythonhosted.org/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" }, - { url = "https://files.pythonhosted.org/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" }, - { url = "https://files.pythonhosted.org/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" }, - { url = "https://files.pythonhosted.org/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" }, - { url = "https://files.pythonhosted.org/packages/89/48/06ee6eabe4fdd9ecd48bf488f4ac783844fd777f547b8d1b61c11939974e/tomli-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5192f562738228945d7b13d4930baffda67b69425a7f0da96d360b0a3888136b", size = 154819, upload-time = "2025-10-08T22:01:17.964Z" }, - { url = "https://files.pythonhosted.org/packages/f1/01/88793757d54d8937015c75dcdfb673c65471945f6be98e6a0410fba167ed/tomli-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:be71c93a63d738597996be9528f4abe628d1adf5e6eb11607bc8fe1a510b5dae", size = 148766, upload-time = "2025-10-08T22:01:18.959Z" }, - { url = "https://files.pythonhosted.org/packages/42/17/5e2c956f0144b812e7e107f94f1cc54af734eb17b5191c0bbfb72de5e93e/tomli-2.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4665508bcbac83a31ff8ab08f424b665200c0e1e645d2bd9ab3d3e557b6185b", size = 240771, upload-time = "2025-10-08T22:01:20.106Z" }, - { url = "https://files.pythonhosted.org/packages/d5/f4/0fbd014909748706c01d16824eadb0307115f9562a15cbb012cd9b3512c5/tomli-2.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4021923f97266babc6ccab9f5068642a0095faa0a51a246a6a02fccbb3514eaf", size = 248586, upload-time = "2025-10-08T22:01:21.164Z" }, - { url = "https://files.pythonhosted.org/packages/30/77/fed85e114bde5e81ecf9bc5da0cc69f2914b38f4708c80ae67d0c10180c5/tomli-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4ea38c40145a357d513bffad0ed869f13c1773716cf71ccaa83b0fa0cc4e42f", size = 244792, upload-time = "2025-10-08T22:01:22.417Z" }, - { url = "https://files.pythonhosted.org/packages/55/92/afed3d497f7c186dc71e6ee6d4fcb0acfa5f7d0a1a2878f8beae379ae0cc/tomli-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad805ea85eda330dbad64c7ea7a4556259665bdf9d2672f5dccc740eb9d3ca05", size = 248909, upload-time = "2025-10-08T22:01:23.859Z" }, - { url = "https://files.pythonhosted.org/packages/f8/84/ef50c51b5a9472e7265ce1ffc7f24cd4023d289e109f669bdb1553f6a7c2/tomli-2.3.0-cp313-cp313-win32.whl", hash = "sha256:97d5eec30149fd3294270e889b4234023f2c69747e555a27bd708828353ab606", size = 96946, upload-time = "2025-10-08T22:01:24.893Z" }, - { url = "https://files.pythonhosted.org/packages/b2/b7/718cd1da0884f281f95ccfa3a6cc572d30053cba64603f79d431d3c9b61b/tomli-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0c95ca56fbe89e065c6ead5b593ee64b84a26fca063b5d71a1122bf26e533999", size = 107705, upload-time = "2025-10-08T22:01:26.153Z" }, - { url = "https://files.pythonhosted.org/packages/19/94/aeafa14a52e16163008060506fcb6aa1949d13548d13752171a755c65611/tomli-2.3.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:cebc6fe843e0733ee827a282aca4999b596241195f43b4cc371d64fc6639da9e", size = 154244, upload-time = "2025-10-08T22:01:27.06Z" }, - { url = "https://files.pythonhosted.org/packages/db/e4/1e58409aa78eefa47ccd19779fc6f36787edbe7d4cd330eeeedb33a4515b/tomli-2.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4c2ef0244c75aba9355561272009d934953817c49f47d768070c3c94355c2aa3", size = 148637, upload-time = "2025-10-08T22:01:28.059Z" }, - { url = "https://files.pythonhosted.org/packages/26/b6/d1eccb62f665e44359226811064596dd6a366ea1f985839c566cd61525ae/tomli-2.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c22a8bf253bacc0cf11f35ad9808b6cb75ada2631c2d97c971122583b129afbc", size = 241925, upload-time = "2025-10-08T22:01:29.066Z" }, - { url = "https://files.pythonhosted.org/packages/70/91/7cdab9a03e6d3d2bb11beae108da5bdc1c34bdeb06e21163482544ddcc90/tomli-2.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0eea8cc5c5e9f89c9b90c4896a8deefc74f518db5927d0e0e8d4a80953d774d0", size = 249045, upload-time = "2025-10-08T22:01:31.98Z" }, - { url = "https://files.pythonhosted.org/packages/15/1b/8c26874ed1f6e4f1fcfeb868db8a794cbe9f227299402db58cfcc858766c/tomli-2.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b74a0e59ec5d15127acdabd75ea17726ac4c5178ae51b85bfe39c4f8a278e879", size = 245835, upload-time = "2025-10-08T22:01:32.989Z" }, - { url = "https://files.pythonhosted.org/packages/fd/42/8e3c6a9a4b1a1360c1a2a39f0b972cef2cc9ebd56025168c4137192a9321/tomli-2.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5870b50c9db823c595983571d1296a6ff3e1b88f734a4c8f6fc6188397de005", size = 253109, upload-time = "2025-10-08T22:01:34.052Z" }, - { url = "https://files.pythonhosted.org/packages/22/0c/b4da635000a71b5f80130937eeac12e686eefb376b8dee113b4a582bba42/tomli-2.3.0-cp314-cp314-win32.whl", hash = "sha256:feb0dacc61170ed7ab602d3d972a58f14ee3ee60494292d384649a3dc38ef463", size = 97930, upload-time = "2025-10-08T22:01:35.082Z" }, - { url = "https://files.pythonhosted.org/packages/b9/74/cb1abc870a418ae99cd5c9547d6bce30701a954e0e721821df483ef7223c/tomli-2.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:b273fcbd7fc64dc3600c098e39136522650c49bca95df2d11cf3b626422392c8", size = 107964, upload-time = "2025-10-08T22:01:36.057Z" }, - { url = "https://files.pythonhosted.org/packages/54/78/5c46fff6432a712af9f792944f4fcd7067d8823157949f4e40c56b8b3c83/tomli-2.3.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:940d56ee0410fa17ee1f12b817b37a4d4e4dc4d27340863cc67236c74f582e77", size = 163065, upload-time = "2025-10-08T22:01:37.27Z" }, - { url = "https://files.pythonhosted.org/packages/39/67/f85d9bd23182f45eca8939cd2bc7050e1f90c41f4a2ecbbd5963a1d1c486/tomli-2.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f85209946d1fe94416debbb88d00eb92ce9cd5266775424ff81bc959e001acaf", size = 159088, upload-time = "2025-10-08T22:01:38.235Z" }, - { url = "https://files.pythonhosted.org/packages/26/5a/4b546a0405b9cc0659b399f12b6adb750757baf04250b148d3c5059fc4eb/tomli-2.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a56212bdcce682e56b0aaf79e869ba5d15a6163f88d5451cbde388d48b13f530", size = 268193, upload-time = "2025-10-08T22:01:39.712Z" }, - { url = "https://files.pythonhosted.org/packages/42/4f/2c12a72ae22cf7b59a7fe75b3465b7aba40ea9145d026ba41cb382075b0e/tomli-2.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5f3ffd1e098dfc032d4d3af5c0ac64f6d286d98bc148698356847b80fa4de1b", size = 275488, upload-time = "2025-10-08T22:01:40.773Z" }, - { url = "https://files.pythonhosted.org/packages/92/04/a038d65dbe160c3aa5a624e93ad98111090f6804027d474ba9c37c8ae186/tomli-2.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5e01decd096b1530d97d5d85cb4dff4af2d8347bd35686654a004f8dea20fc67", size = 272669, upload-time = "2025-10-08T22:01:41.824Z" }, - { url = "https://files.pythonhosted.org/packages/be/2f/8b7c60a9d1612a7cbc39ffcca4f21a73bf368a80fc25bccf8253e2563267/tomli-2.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8a35dd0e643bb2610f156cca8db95d213a90015c11fee76c946aa62b7ae7e02f", size = 279709, upload-time = "2025-10-08T22:01:43.177Z" }, - { url = "https://files.pythonhosted.org/packages/7e/46/cc36c679f09f27ded940281c38607716c86cf8ba4a518d524e349c8b4874/tomli-2.3.0-cp314-cp314t-win32.whl", hash = "sha256:a1f7f282fe248311650081faafa5f4732bdbfef5d45fe3f2e702fbc6f2d496e0", size = 107563, upload-time = "2025-10-08T22:01:44.233Z" }, - { url = "https://files.pythonhosted.org/packages/84/ff/426ca8683cf7b753614480484f6437f568fd2fda2edbdf57a2d3d8b27a0b/tomli-2.3.0-cp314-cp314t-win_amd64.whl", hash = "sha256:70a251f8d4ba2d9ac2542eecf008b3c8a9fc5c3f9f02c56a9d7952612be2fdba", size = 119756, upload-time = "2025-10-08T22:01:45.234Z" }, - { url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" }, -] - -[[package]] -name = "typing-extensions" -version = "4.15.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, -] diff --git a/web/.env.example b/web/.env.example index eff6f77fd9..b488c31057 100644 --- a/web/.env.example +++ b/web/.env.example @@ -70,3 +70,6 @@ NEXT_PUBLIC_ENABLE_SINGLE_DOLLAR_LATEX=false # The maximum number of tree node depth for workflow NEXT_PUBLIC_MAX_TREE_DEPTH=50 + +# The API key of amplitude +NEXT_PUBLIC_AMPLITUDE_API_KEY= diff --git a/web/.husky/pre-commit b/web/.husky/pre-commit index 26e9bf69d4..dd4140b47e 100644 --- a/web/.husky/pre-commit +++ b/web/.husky/pre-commit @@ -61,13 +61,13 @@ if $web_modified; then lint-staged if $web_ts_modified; then - echo "Running TypeScript type-check" - if ! pnpm run type-check; then - echo "Type check failed. Please run 'pnpm run type-check' to fix the errors." + echo "Running TypeScript type-check:tsgo" + if ! pnpm run type-check:tsgo; then + echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors." exit 1 fi else - echo "No staged TypeScript changes detected, skipping type-check" + echo "No staged TypeScript changes detected, skipping type-check:tsgo" fi echo "Running unit tests check" diff --git a/web/__mocks__/provider-context.ts b/web/__mocks__/provider-context.ts new file mode 100644 index 0000000000..594fe38f14 --- /dev/null +++ b/web/__mocks__/provider-context.ts @@ -0,0 +1,47 @@ +import { merge, noop } from 'lodash-es' +import { defaultPlan } from '@/app/components/billing/config' +import { baseProviderContextValue } from '@/context/provider-context' +import type { ProviderContextState } from '@/context/provider-context' +import type { Plan, UsagePlanInfo } from '@/app/components/billing/type' + +export const createMockProviderContextValue = (overrides: Partial = {}): ProviderContextState => { + const merged = merge({}, baseProviderContextValue, overrides) + + return { + ...merged, + refreshModelProviders: merged.refreshModelProviders ?? noop, + onPlanInfoChanged: merged.onPlanInfoChanged ?? noop, + refreshLicenseLimit: merged.refreshLicenseLimit ?? noop, + } +} + +export const createMockPlan = (plan: Plan): ProviderContextState => + createMockProviderContextValue({ + plan: merge({}, defaultPlan, { + type: plan, + }), + }) + +export const createMockPlanUsage = (usage: UsagePlanInfo, ctx: Partial): ProviderContextState => + createMockProviderContextValue({ + ...ctx, + plan: merge(ctx.plan, { + usage, + }), + }) + +export const createMockPlanTotal = (total: UsagePlanInfo, ctx: Partial): ProviderContextState => + createMockProviderContextValue({ + ...ctx, + plan: merge(ctx.plan, { + total, + }), + }) + +export const createMockPlanReset = (reset: Partial, ctx: Partial): ProviderContextState => + createMockProviderContextValue({ + ...ctx, + plan: merge(ctx?.plan, { + reset, + }), + }) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index da8839e869..3effb79f20 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -121,7 +121,7 @@ const DatasetDetailLayout: FC = (props) => {
{ return ( <> + diff --git a/web/app/(commonLayout)/plugins/page.tsx b/web/app/(commonLayout)/plugins/page.tsx index d07c4307ad..ad61b16ba2 100644 --- a/web/app/(commonLayout)/plugins/page.tsx +++ b/web/app/(commonLayout)/plugins/page.tsx @@ -8,7 +8,7 @@ const PluginList = async () => { return ( } - marketplace={} + marketplace={} /> ) } diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx index 44006a9f1e..219722eef3 100644 --- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx +++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx @@ -94,8 +94,8 @@ const NormalForm = () => { <>
-

{t('login.pageTitle')}

- {!systemFeatures.branding.enabled &&

{t('login.welcome')}

} +

{systemFeatures.branding.enabled ? t('login.pageTitleForE') : t('login.pageTitle')}

+

{t('login.welcome')}

diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx index d8943b7879..ef8f6334f1 100644 --- a/web/app/account/(commonLayout)/avatar.tsx +++ b/web/app/account/(commonLayout)/avatar.tsx @@ -12,6 +12,7 @@ import { useProviderContext } from '@/context/provider-context' import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' import PremiumBadge from '@/app/components/base/premium-badge' import { useLogout } from '@/service/use-common' +import { resetUser } from '@/app/components/base/amplitude/utils' export type IAppSelector = { isMobile: boolean @@ -28,6 +29,7 @@ export default function AppSelector() { await logout() localStorage.removeItem('setup_status') + resetUser() // Tokens are now stored in cookies and cleared by backend router.push('/signin') diff --git a/web/app/account/(commonLayout)/layout.tsx b/web/app/account/(commonLayout)/layout.tsx index b3225b5341..b661c130eb 100644 --- a/web/app/account/(commonLayout)/layout.tsx +++ b/web/app/account/(commonLayout)/layout.tsx @@ -4,6 +4,7 @@ import Header from './header' import SwrInitor from '@/app/components/swr-initializer' import { AppContextProvider } from '@/context/app-context' import GA, { GaType } from '@/app/components/base/ga' +import AmplitudeProvider from '@/app/components/base/amplitude' import HeaderWrapper from '@/app/components/header/header-wrapper' import { EventEmitterContextProvider } from '@/context/event-emitter' import { ProviderContextProvider } from '@/context/provider-context' @@ -13,6 +14,7 @@ const Layout = ({ children }: { children: ReactNode }) => { return ( <> + diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx new file mode 100644 index 0000000000..91e1e9d8fe --- /dev/null +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx @@ -0,0 +1,121 @@ +import React from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import CSVUploader, { type Props } from './csv-uploader' +import { ToastContext } from '@/app/components/base/toast' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('CSVUploader', () => { + const notify = jest.fn() + const updateFile = jest.fn() + + const getDropElements = () => { + const title = screen.getByText('appAnnotation.batchModal.csvUploadTitle') + const dropZone = title.parentElement?.parentElement as HTMLDivElement | null + if (!dropZone || !dropZone.parentElement) + throw new Error('Drop zone not found') + const dropContainer = dropZone.parentElement as HTMLDivElement + return { dropZone, dropContainer } + } + + const renderComponent = (props?: Partial) => { + const mergedProps: Props = { + file: undefined, + updateFile, + ...props, + } + return render( + + + , + ) + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should open the file picker when clicking browse', () => { + const clickSpy = jest.spyOn(HTMLInputElement.prototype, 'click') + renderComponent() + + fireEvent.click(screen.getByText('appAnnotation.batchModal.browse')) + + expect(clickSpy).toHaveBeenCalledTimes(1) + clickSpy.mockRestore() + }) + + it('should toggle dragging styles and upload the dropped file', async () => { + const file = new File(['content'], 'input.csv', { type: 'text/csv' }) + renderComponent() + const { dropZone, dropContainer } = getDropElements() + + fireEvent.dragEnter(dropContainer) + expect(dropZone.className).toContain('border-components-dropzone-border-accent') + expect(dropZone.className).toContain('bg-components-dropzone-bg-accent') + + fireEvent.drop(dropContainer, { dataTransfer: { files: [file] } }) + + await waitFor(() => expect(updateFile).toHaveBeenCalledWith(file)) + expect(dropZone.className).not.toContain('border-components-dropzone-border-accent') + }) + + it('should ignore drop events without dataTransfer', () => { + renderComponent() + const { dropContainer } = getDropElements() + + fireEvent.drop(dropContainer) + + expect(updateFile).not.toHaveBeenCalled() + }) + + it('should show an error when multiple files are dropped', async () => { + const fileA = new File(['a'], 'a.csv', { type: 'text/csv' }) + const fileB = new File(['b'], 'b.csv', { type: 'text/csv' }) + renderComponent() + const { dropContainer } = getDropElements() + + fireEvent.drop(dropContainer, { dataTransfer: { files: [fileA, fileB] } }) + + await waitFor(() => expect(notify).toHaveBeenCalledWith({ + type: 'error', + message: 'datasetCreation.stepOne.uploader.validation.count', + })) + expect(updateFile).not.toHaveBeenCalled() + }) + + it('should propagate file selection changes through input change event', () => { + const file = new File(['row'], 'selected.csv', { type: 'text/csv' }) + const { container } = renderComponent() + const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement + + fireEvent.change(fileInput, { target: { files: [file] } }) + + expect(updateFile).toHaveBeenCalledWith(file) + }) + + it('should render selected file details and allow change/removal', () => { + const file = new File(['data'], 'report.csv', { type: 'text/csv' }) + const { container } = renderComponent({ file }) + const fileInput = container.querySelector('input[type="file"]') as HTMLInputElement + + expect(screen.getByText('report')).toBeInTheDocument() + expect(screen.getByText('.csv')).toBeInTheDocument() + + const clickSpy = jest.spyOn(HTMLInputElement.prototype, 'click') + fireEvent.click(screen.getByText('datasetCreation.stepOne.uploader.change')) + expect(clickSpy).toHaveBeenCalled() + clickSpy.mockRestore() + + const valueSetter = jest.spyOn(fileInput, 'value', 'set') + const removeTrigger = screen.getByTestId('remove-file-button') + fireEvent.click(removeTrigger) + + expect(updateFile).toHaveBeenCalledWith() + expect(valueSetter).toHaveBeenCalledWith('') + }) +}) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index b98eb815f9..ccad46b860 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -114,7 +114,7 @@ const CSVUploader: FC = ({
-
+
diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index bba5ebfa21..2dc45e1337 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -42,6 +42,7 @@ import type { InputVar, Variable } from '@/app/components/workflow/types' import { appDefaultIconBackground } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' +import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { AccessMode } from '@/models/access-control' import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/access-control' import { fetchAppDetailDirect } from '@/service/apps' @@ -153,6 +154,7 @@ const AppPublisher = ({ const { data: userCanAccessApp, isLoading: isGettingUserCanAccessApp, refetch } = useGetUserCanAccessApp({ appId: appDetail?.id, enabled: false }) const { data: appAccessSubjects, isLoading: isGettingAppWhiteListSubjects } = useAppWhiteListSubjects(appDetail?.id, open && systemFeatures.webapp_auth.enabled && appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS) + const openAsyncWindow = useAsyncWindowOpen() const noAccessPermission = useMemo(() => systemFeatures.webapp_auth.enabled && appDetail && appDetail.access_mode !== AccessMode.EXTERNAL_MEMBERS && !userCanAccessApp?.result, [systemFeatures, appDetail, userCanAccessApp]) const disabledFunctionButton = useMemo(() => (!publishedAt || missingStartNode || noAccessPermission), [publishedAt, missingStartNode, noAccessPermission]) @@ -217,17 +219,19 @@ const AppPublisher = ({ }, [disabled, onToggle, open]) const handleOpenInExplore = useCallback(async () => { - try { + await openAsyncWindow(async () => { + if (!appDetail?.id) + throw new Error('App not found') const { installed_apps }: any = await fetchInstalledAppList(appDetail?.id) || {} if (installed_apps?.length > 0) - window.open(`${basePath}/explore/installed/${installed_apps[0].id}`, '_blank') - else - throw new Error('No app found in Explore') - } - catch (e: any) { - Toast.notify({ type: 'error', message: `${e.message || e}` }) - } - }, [appDetail?.id]) + return `${basePath}/explore/installed/${installed_apps[0].id}` + throw new Error('No app found in Explore') + }, { + onError: (err) => { + Toast.notify({ type: 'error', message: `${err.message || err}` }) + }, + }) + }, [appDetail?.id, openAsyncWindow]) const handleAccessControlUpdate = useCallback(async () => { if (!appDetail) diff --git a/web/app/components/app/configuration/config-prompt/confirm-add-var/index.spec.tsx b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.spec.tsx new file mode 100644 index 0000000000..7ffafbb172 --- /dev/null +++ b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.spec.tsx @@ -0,0 +1,49 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import ConfirmAddVar from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +jest.mock('../../base/var-highlight', () => ({ + __esModule: true, + default: ({ name }: { name: string }) => {name}, +})) + +describe('ConfirmAddVar', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should render variable names', () => { + render() + + const highlights = screen.getAllByTestId('var-highlight') + expect(highlights).toHaveLength(2) + expect(highlights[0]).toHaveTextContent('foo') + expect(highlights[1]).toHaveTextContent('bar') + }) + + it('should trigger cancel actions', () => { + const onConfirm = jest.fn() + const onCancel = jest.fn() + render() + + fireEvent.click(screen.getByText('common.operation.cancel')) + + expect(onCancel).toHaveBeenCalledTimes(1) + }) + + it('should trigger confirm actions', () => { + const onConfirm = jest.fn() + const onCancel = jest.fn() + render() + + fireEvent.click(screen.getByText('common.operation.add')) + + expect(onConfirm).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.spec.tsx b/web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.spec.tsx new file mode 100644 index 0000000000..652f5409e8 --- /dev/null +++ b/web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.spec.tsx @@ -0,0 +1,56 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import EditModal from './edit-modal' +import type { ConversationHistoriesRole } from '@/models/debug' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +jest.mock('@/app/components/base/modal', () => ({ + __esModule: true, + default: ({ children }: { children: React.ReactNode }) =>
{children}
, +})) + +describe('Conversation history edit modal', () => { + const data: ConversationHistoriesRole = { + user_prefix: 'user', + assistant_prefix: 'assistant', + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should render provided prefixes', () => { + render() + + expect(screen.getByDisplayValue('user')).toBeInTheDocument() + expect(screen.getByDisplayValue('assistant')).toBeInTheDocument() + }) + + it('should update prefixes and save changes', () => { + const onSave = jest.fn() + render() + + fireEvent.change(screen.getByDisplayValue('user'), { target: { value: 'member' } }) + fireEvent.change(screen.getByDisplayValue('assistant'), { target: { value: 'helper' } }) + fireEvent.click(screen.getByText('common.operation.save')) + + expect(onSave).toHaveBeenCalledWith({ + user_prefix: 'member', + assistant_prefix: 'helper', + }) + }) + + it('should call close handler', () => { + const onClose = jest.fn() + render() + + fireEvent.click(screen.getByText('common.operation.cancel')) + + expect(onClose).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.spec.tsx b/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.spec.tsx new file mode 100644 index 0000000000..61e361c057 --- /dev/null +++ b/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.spec.tsx @@ -0,0 +1,48 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import HistoryPanel from './history-panel' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +const mockDocLink = jest.fn(() => 'doc-link') +jest.mock('@/context/i18n', () => ({ + useDocLink: () => mockDocLink, +})) + +jest.mock('@/app/components/app/configuration/base/operation-btn', () => ({ + __esModule: true, + default: ({ onClick }: { onClick: () => void }) => ( + + ), +})) + +jest.mock('@/app/components/app/configuration/base/feature-panel', () => ({ + __esModule: true, + default: ({ children }: { children: React.ReactNode }) =>
{children}
, +})) + +describe('HistoryPanel', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should render warning content and link when showWarning is true', () => { + render() + + expect(screen.getByText('appDebug.feature.conversationHistory.tip')).toBeInTheDocument() + const link = screen.getByText('appDebug.feature.conversationHistory.learnMore') + expect(link).toHaveAttribute('href', 'doc-link') + }) + + it('should hide warning when showWarning is false', () => { + render() + + expect(screen.queryByText('appDebug.feature.conversationHistory.tip')).toBeNull() + }) +}) diff --git a/web/app/components/app/configuration/config-prompt/index.spec.tsx b/web/app/components/app/configuration/config-prompt/index.spec.tsx new file mode 100644 index 0000000000..b2098862da --- /dev/null +++ b/web/app/components/app/configuration/config-prompt/index.spec.tsx @@ -0,0 +1,351 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import Prompt, { type IPromptProps } from './index' +import ConfigContext from '@/context/debug-configuration' +import { MAX_PROMPT_MESSAGE_LENGTH } from '@/config' +import { type PromptItem, PromptRole, type PromptVariable } from '@/models/debug' +import { AppModeEnum, ModelModeType } from '@/types/app' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +type DebugConfiguration = { + isAdvancedMode: boolean + currentAdvancedPrompt: PromptItem | PromptItem[] + setCurrentAdvancedPrompt: (prompt: PromptItem | PromptItem[], isUserChanged?: boolean) => void + modelModeType: ModelModeType + dataSets: Array<{ + id: string + name?: string + }> + hasSetBlockStatus: { + context: boolean + history: boolean + query: boolean + } +} + +const defaultPromptVariables: PromptVariable[] = [ + { key: 'var', name: 'Variable', type: 'string', required: true }, +] + +let mockSimplePromptInputProps: IPromptProps | null = null + +jest.mock('./simple-prompt-input', () => ({ + __esModule: true, + default: (props: IPromptProps) => { + mockSimplePromptInputProps = props + return ( +
props.onChange?.('mocked prompt', props.promptVariables)} + > + SimplePromptInput Mock +
+ ) + }, +})) + +type AdvancedMessageInputProps = { + isChatMode: boolean + type: PromptRole + value: string + onTypeChange: (value: PromptRole) => void + canDelete: boolean + onDelete: () => void + onChange: (value: string) => void + promptVariables: PromptVariable[] + isContextMissing: boolean + onHideContextMissingTip: () => void + noResize?: boolean +} + +jest.mock('./advanced-prompt-input', () => ({ + __esModule: true, + default: (props: AdvancedMessageInputProps) => { + return ( +
+ + + + +
+ ) + }, +})) +const getContextValue = (overrides: Partial = {}): DebugConfiguration => { + return { + setCurrentAdvancedPrompt: jest.fn(), + isAdvancedMode: false, + currentAdvancedPrompt: [], + modelModeType: ModelModeType.chat, + dataSets: [], + hasSetBlockStatus: { + context: false, + history: false, + query: false, + }, + ...overrides, + } +} + +const renderComponent = ( + props: Partial = {}, + contextOverrides: Partial = {}, +) => { + const mergedProps: IPromptProps = { + mode: AppModeEnum.CHAT, + promptTemplate: 'initial template', + promptVariables: defaultPromptVariables, + onChange: jest.fn(), + ...props, + } + const contextValue = getContextValue(contextOverrides) + + return { + contextValue, + ...render( + + + , + ), + } +} + +describe('Prompt config component', () => { + beforeEach(() => { + jest.clearAllMocks() + mockSimplePromptInputProps = null + }) + + // Rendering simple mode + it('should render simple prompt when advanced mode is disabled', () => { + const onChange = jest.fn() + renderComponent({ onChange }, { isAdvancedMode: false }) + + const simplePrompt = screen.getByTestId('simple-prompt-input') + expect(simplePrompt).toBeInTheDocument() + expect(simplePrompt).toHaveAttribute('data-mode', AppModeEnum.CHAT) + expect(mockSimplePromptInputProps?.promptTemplate).toBe('initial template') + fireEvent.click(simplePrompt) + expect(onChange).toHaveBeenCalledWith('mocked prompt', defaultPromptVariables) + expect(screen.queryByTestId('advanced-message-input')).toBeNull() + }) + + // Rendering advanced chat messages + it('should render advanced chat prompts and show context missing tip when dataset context is not set', () => { + const currentAdvancedPrompt: PromptItem[] = [ + { role: PromptRole.user, text: 'first' }, + { role: PromptRole.assistant, text: 'second' }, + ] + renderComponent( + {}, + { + isAdvancedMode: true, + currentAdvancedPrompt, + modelModeType: ModelModeType.chat, + dataSets: [{ id: 'ds' } as unknown as DebugConfiguration['dataSets'][number]], + hasSetBlockStatus: { context: false, history: true, query: true }, + }, + ) + + const renderedMessages = screen.getAllByTestId('advanced-message-input') + expect(renderedMessages).toHaveLength(2) + expect(renderedMessages[0]).toHaveAttribute('data-context-missing', 'true') + fireEvent.click(screen.getAllByText('hide-context')[0]) + expect(screen.getAllByTestId('advanced-message-input')[0]).toHaveAttribute('data-context-missing', 'false') + }) + + // Chat message mutations + it('should update chat prompt value and call setter with user change flag', () => { + const currentAdvancedPrompt: PromptItem[] = [ + { role: PromptRole.user, text: 'first' }, + { role: PromptRole.assistant, text: 'second' }, + ] + const setCurrentAdvancedPrompt = jest.fn() + renderComponent( + {}, + { + isAdvancedMode: true, + currentAdvancedPrompt, + modelModeType: ModelModeType.chat, + setCurrentAdvancedPrompt, + }, + ) + + fireEvent.click(screen.getAllByText('change')[0]) + expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith( + [ + { role: PromptRole.user, text: 'updated text' }, + { role: PromptRole.assistant, text: 'second' }, + ], + true, + ) + }) + + it('should update chat prompt role when type changes', () => { + const currentAdvancedPrompt: PromptItem[] = [ + { role: PromptRole.user, text: 'first' }, + { role: PromptRole.user, text: 'second' }, + ] + const setCurrentAdvancedPrompt = jest.fn() + renderComponent( + {}, + { + isAdvancedMode: true, + currentAdvancedPrompt, + modelModeType: ModelModeType.chat, + setCurrentAdvancedPrompt, + }, + ) + + fireEvent.click(screen.getAllByText('type')[1]) + expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith( + [ + { role: PromptRole.user, text: 'first' }, + { role: PromptRole.assistant, text: 'second' }, + ], + ) + }) + + it('should delete chat prompt item', () => { + const currentAdvancedPrompt: PromptItem[] = [ + { role: PromptRole.user, text: 'first' }, + { role: PromptRole.assistant, text: 'second' }, + ] + const setCurrentAdvancedPrompt = jest.fn() + renderComponent( + {}, + { + isAdvancedMode: true, + currentAdvancedPrompt, + modelModeType: ModelModeType.chat, + setCurrentAdvancedPrompt, + }, + ) + + fireEvent.click(screen.getAllByText('delete')[0]) + expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith([{ role: PromptRole.assistant, text: 'second' }]) + }) + + // Add message behavior + it('should append a mirrored role message when clicking add in chat mode', () => { + const currentAdvancedPrompt: PromptItem[] = [ + { role: PromptRole.user, text: 'first' }, + ] + const setCurrentAdvancedPrompt = jest.fn() + renderComponent( + {}, + { + isAdvancedMode: true, + currentAdvancedPrompt, + modelModeType: ModelModeType.chat, + setCurrentAdvancedPrompt, + }, + ) + + fireEvent.click(screen.getByText('appDebug.promptMode.operation.addMessage')) + expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith([ + { role: PromptRole.user, text: 'first' }, + { role: PromptRole.assistant, text: '' }, + ]) + }) + + it('should append a user role when the last chat prompt is from assistant', () => { + const currentAdvancedPrompt: PromptItem[] = [ + { role: PromptRole.assistant, text: 'reply' }, + ] + const setCurrentAdvancedPrompt = jest.fn() + renderComponent( + {}, + { + isAdvancedMode: true, + currentAdvancedPrompt, + modelModeType: ModelModeType.chat, + setCurrentAdvancedPrompt, + }, + ) + + fireEvent.click(screen.getByText('appDebug.promptMode.operation.addMessage')) + expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith([ + { role: PromptRole.assistant, text: 'reply' }, + { role: PromptRole.user, text: '' }, + ]) + }) + + it('should insert a system message when adding to an empty chat prompt list', () => { + const setCurrentAdvancedPrompt = jest.fn() + renderComponent( + {}, + { + isAdvancedMode: true, + currentAdvancedPrompt: [], + modelModeType: ModelModeType.chat, + setCurrentAdvancedPrompt, + }, + ) + + fireEvent.click(screen.getByText('appDebug.promptMode.operation.addMessage')) + expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith([{ role: PromptRole.system, text: '' }]) + }) + + it('should not show add button when reaching max prompt length', () => { + const prompts: PromptItem[] = Array.from({ length: MAX_PROMPT_MESSAGE_LENGTH }, (_, index) => ({ + role: PromptRole.user, + text: `item-${index}`, + })) + renderComponent( + {}, + { + isAdvancedMode: true, + currentAdvancedPrompt: prompts, + modelModeType: ModelModeType.chat, + }, + ) + + expect(screen.queryByText('appDebug.promptMode.operation.addMessage')).toBeNull() + }) + + // Completion mode + it('should update completion prompt value and flag as user change', () => { + const setCurrentAdvancedPrompt = jest.fn() + renderComponent( + {}, + { + isAdvancedMode: true, + currentAdvancedPrompt: { role: PromptRole.user, text: 'single' }, + modelModeType: ModelModeType.completion, + setCurrentAdvancedPrompt, + }, + ) + + fireEvent.click(screen.getByText('change')) + + expect(setCurrentAdvancedPrompt).toHaveBeenCalledWith({ role: PromptRole.user, text: 'updated text' }, true) + }) +}) diff --git a/web/app/components/app/configuration/config-prompt/message-type-selector.spec.tsx b/web/app/components/app/configuration/config-prompt/message-type-selector.spec.tsx new file mode 100644 index 0000000000..4401b7e57e --- /dev/null +++ b/web/app/components/app/configuration/config-prompt/message-type-selector.spec.tsx @@ -0,0 +1,37 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import MessageTypeSelector from './message-type-selector' +import { PromptRole } from '@/models/debug' + +describe('MessageTypeSelector', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should render current value and keep options hidden by default', () => { + render() + + expect(screen.getByText(PromptRole.user)).toBeInTheDocument() + expect(screen.queryByText(PromptRole.system)).toBeNull() + }) + + it('should toggle option list when clicking the selector', () => { + render() + + fireEvent.click(screen.getByText(PromptRole.system)) + + expect(screen.getByText(PromptRole.user)).toBeInTheDocument() + expect(screen.getByText(PromptRole.assistant)).toBeInTheDocument() + }) + + it('should call onChange with selected type and close the list', () => { + const onChange = jest.fn() + render() + + fireEvent.click(screen.getByText(PromptRole.assistant)) + fireEvent.click(screen.getByText(PromptRole.user)) + + expect(onChange).toHaveBeenCalledWith(PromptRole.user) + expect(screen.queryByText(PromptRole.system)).toBeNull() + }) +}) diff --git a/web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.spec.tsx b/web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.spec.tsx new file mode 100644 index 0000000000..d6bef4cdd7 --- /dev/null +++ b/web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.spec.tsx @@ -0,0 +1,66 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap' + +describe('PromptEditorHeightResizeWrap', () => { + beforeEach(() => { + jest.clearAllMocks() + jest.useFakeTimers() + }) + + afterEach(() => { + jest.runOnlyPendingTimers() + jest.useRealTimers() + }) + + it('should render children, footer, and hide resize handler when requested', () => { + const { container } = render( + footer
} + hideResize + > +
content
+ , + ) + + expect(screen.getByText('content')).toBeInTheDocument() + expect(screen.getByText('footer')).toBeInTheDocument() + expect(container.querySelector('.cursor-row-resize')).toBeNull() + }) + + it('should resize height with mouse events and clamp to minHeight', () => { + const onHeightChange = jest.fn() + + const { container } = render( + +
content
+
, + ) + + const handle = container.querySelector('.cursor-row-resize') + expect(handle).not.toBeNull() + + fireEvent.mouseDown(handle as Element, { clientY: 100 }) + expect(document.body.style.userSelect).toBe('none') + + fireEvent.mouseMove(document, { clientY: 130 }) + jest.runAllTimers() + expect(onHeightChange).toHaveBeenLastCalledWith(180) + + onHeightChange.mockClear() + fireEvent.mouseMove(document, { clientY: -100 }) + jest.runAllTimers() + expect(onHeightChange).toHaveBeenLastCalledWith(100) + + fireEvent.mouseUp(document) + expect(document.body.style.userSelect).toBe('') + }) +}) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index f2b9c105fc..5716bfd92d 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -32,6 +32,7 @@ import { canFindTool } from '@/utils' import { useAllBuiltInTools, useAllCustomTools, useAllMCPTools, useAllWorkflowTools } from '@/service/use-tools' import type { ToolWithProvider } from '@/app/components/workflow/types' import { useMittContextSelector } from '@/context/mitt-context' +import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' type AgentToolWithMoreInfo = AgentTool & { icon: any; collection?: Collection } | null const AgentTools: FC = () => { @@ -93,13 +94,17 @@ const AgentTools: FC = () => { const [isDeleting, setIsDeleting] = useState(-1) const getToolValue = (tool: ToolDefaultValue) => { + const currToolInCollections = collectionList.find(c => c.id === tool.provider_id) + const currToolWithConfigs = currToolInCollections?.tools.find(t => t.name === tool.tool_name) + const formSchemas = currToolWithConfigs ? toolParametersToFormSchemas(currToolWithConfigs.parameters) : [] + const paramsWithDefaultValue = addDefaultValue(tool.params, formSchemas) return { provider_id: tool.provider_id, provider_type: tool.provider_type as CollectionType, provider_name: tool.provider_name, tool_name: tool.tool_name, tool_label: tool.tool_label, - tool_parameters: tool.params, + tool_parameters: paramsWithDefaultValue, notAuthor: !tool.is_team_authorization, enabled: true, } @@ -119,7 +124,7 @@ const AgentTools: FC = () => { } const getProviderShowName = (item: AgentTool) => { const type = item.provider_type - if(type === CollectionType.builtIn) + if (type === CollectionType.builtIn) return item.provider_name.split('/').pop() return item.provider_name } @@ -251,6 +256,7 @@ const AgentTools: FC = () => { {!item.notAuthor && (
{ setCurrentTool(item) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx index ef28dd222c..c5947495db 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx @@ -16,7 +16,7 @@ import Description from '@/app/components/plugins/card/base/description' import TabSlider from '@/app/components/base/tab-slider-plain' import Button from '@/app/components/base/button' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' -import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' +import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import type { Collection, Tool } from '@/app/components/tools/types' import { CollectionType } from '@/app/components/tools/types' import { fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList, fetchWorkflowToolList } from '@/service/tools' @@ -92,15 +92,11 @@ const SettingBuiltInTool: FC = ({ }()) }) setTools(list) - const currTool = list.find(tool => tool.name === toolName) - if (currTool) { - const formSchemas = toolParametersToFormSchemas(currTool.parameters) - setTempSetting(addDefaultValue(setting, formSchemas)) - } } catch { } setIsLoading(false) })() + // eslint-disable-next-line react-hooks/exhaustive-deps }, [collection?.name, collection?.id, collection?.type]) useEffect(() => { @@ -249,7 +245,7 @@ const SettingBuiltInTool: FC = ({ {!readonly && !isInfoActive && (
- +
)}
diff --git a/web/app/components/app/configuration/dataset-config/index.tsx b/web/app/components/app/configuration/dataset-config/index.tsx index bf81858565..44a54f8e8b 100644 --- a/web/app/components/app/configuration/dataset-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/index.tsx @@ -77,7 +77,7 @@ const DatasetConfig: FC = () => { const oldRetrievalConfig = { top_k, score_threshold, - reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { + reranking_model: (reranking_model && reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { provider: reranking_model.reranking_provider_name, model: reranking_model.reranking_model_name, } : undefined, diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index feb7a38165..6857c38e1e 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -1,18 +1,20 @@ 'use client' import type { FC } from 'react' -import React, { useRef, useState } from 'react' -import { useGetState, useInfiniteScroll } from 'ahooks' +import React, { useEffect, useMemo, useRef, useState } from 'react' +import { useInfiniteScroll } from 'ahooks' import { useTranslation } from 'react-i18next' import Link from 'next/link' import Modal from '@/app/components/base/modal' import type { DataSet } from '@/models/datasets' import Button from '@/app/components/base/button' -import { fetchDatasets } from '@/service/datasets' import Loading from '@/app/components/base/loading' import Badge from '@/app/components/base/badge' import { useKnowledge } from '@/hooks/use-knowledge' import cn from '@/utils/classnames' import AppIcon from '@/app/components/base/app-icon' +import { useInfiniteDatasets } from '@/service/knowledge/use-dataset' +import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import FeatureIcon from '@/app/components/header/account-setting/model-provider-page/model-selector/feature-icon' export type ISelectDataSetProps = { isShow: boolean @@ -28,51 +30,70 @@ const SelectDataSet: FC = ({ onSelect, }) => { const { t } = useTranslation() - const [selected, setSelected] = React.useState([]) - const [loaded, setLoaded] = React.useState(false) - const [datasets, setDataSets] = React.useState(null) - const [hasInitialized, setHasInitialized] = React.useState(false) - const hasNoData = !datasets || datasets?.length === 0 + const [selected, setSelected] = useState([]) const canSelectMulti = true + const { formatIndexingTechniqueAndMethod } = useKnowledge() + const { data, isLoading, isFetchingNextPage, fetchNextPage, hasNextPage } = useInfiniteDatasets( + { page: 1 }, + { enabled: isShow, staleTime: 0, refetchOnMount: 'always' }, + ) + const pages = data?.pages || [] + const datasets = useMemo(() => { + return pages.flatMap(page => page.data.filter(item => item.indexing_technique || item.provider === 'external')) + }, [pages]) + const hasNoData = !isLoading && datasets.length === 0 const listRef = useRef(null) - const [page, setPage, getPage] = useGetState(1) - const [isNoMore, setIsNoMore] = useState(false) - const { formatIndexingTechniqueAndMethod } = useKnowledge() + const isNoMore = hasNextPage === false useInfiniteScroll( async () => { - if (!isNoMore) { - const { data, has_more } = await fetchDatasets({ url: '/datasets', params: { page } }) - setPage(getPage() + 1) - setIsNoMore(!has_more) - const newList = [...(datasets || []), ...data.filter(item => item.indexing_technique || item.provider === 'external')] - setDataSets(newList) - setLoaded(true) - - // Initialize selected datasets based on selectedIds and available datasets - if (!hasInitialized) { - if (selectedIds.length > 0) { - const validSelectedDatasets = selectedIds - .map(id => newList.find(item => item.id === id)) - .filter(Boolean) as DataSet[] - setSelected(validSelectedDatasets) - } - setHasInitialized(true) - } - } + if (!hasNextPage || isFetchingNextPage) + return { list: [] } + await fetchNextPage() return { list: [] } }, { target: listRef, - isNoMore: () => { - return isNoMore - }, - reloadDeps: [isNoMore], + isNoMore: () => isNoMore, + reloadDeps: [isNoMore, isFetchingNextPage], }, ) + const prevSelectedIdsRef = useRef([]) + const hasUserModifiedSelectionRef = useRef(false) + useEffect(() => { + if (isShow) + hasUserModifiedSelectionRef.current = false + }, [isShow]) + useEffect(() => { + const prevSelectedIds = prevSelectedIdsRef.current + const idsChanged = selectedIds.length !== prevSelectedIds.length + || selectedIds.some((id, idx) => id !== prevSelectedIds[idx]) + + if (!selectedIds.length && (!hasUserModifiedSelectionRef.current || idsChanged)) { + setSelected([]) + prevSelectedIdsRef.current = selectedIds + hasUserModifiedSelectionRef.current = false + return + } + + if (!idsChanged && hasUserModifiedSelectionRef.current) + return + + setSelected((prev) => { + const prevMap = new Map(prev.map(item => [item.id, item])) + const nextSelected = selectedIds + .map(id => datasets.find(item => item.id === id) || prevMap.get(id)) + .filter(Boolean) as DataSet[] + return nextSelected + }) + prevSelectedIdsRef.current = selectedIds + hasUserModifiedSelectionRef.current = false + }, [datasets, selectedIds]) + const toggleSelect = (dataSet: DataSet) => { + hasUserModifiedSelectionRef.current = true const isSelected = selected.some(item => item.id === dataSet.id) if (isSelected) { setSelected(selected.filter(item => item.id !== dataSet.id)) @@ -96,13 +117,13 @@ const SelectDataSet: FC = ({ className='w-[400px]' title={t('appDebug.feature.dataSet.selectTitle')} > - {!loaded && ( + {(isLoading && datasets.length === 0) && (
)} - {(loaded && hasNoData) && ( + {hasNoData && (
= ({
)} - {datasets && datasets?.length > 0 && ( + {datasets.length > 0 && ( <>
{datasets.map(item => (
i.id === item.id) && 'border-[1.5px] border-components-option-card-option-selected-border bg-state-accent-hover shadow-xs hover:border-components-option-card-option-selected-border hover:bg-state-accent-hover hover:shadow-xs', !item.embedding_available && 'hover:border-components-panel-border-subtle hover:bg-components-panel-on-panel-item-bg hover:shadow-xs', )} @@ -131,7 +152,7 @@ const SelectDataSet: FC = ({ toggleSelect(item) }} > -
+
= ({ {t('dataset.unavailable')} )}
+ {item.is_multimodal && ( +
+ +
+ )} { item.indexing_technique && ( = ({
)} - {loaded && ( + {!isLoading && (
{selected.length > 0 && `${selected.length} ${t('appDebug.feature.dataSet.selected')}`} diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index 93d0384aee..cd6e39011e 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import { useRef, useState } from 'react' +import { useMemo, useRef, useState } from 'react' import { useMount } from 'ahooks' import { useTranslation } from 'react-i18next' import { isEqual } from 'lodash-es' @@ -25,15 +25,13 @@ import { isReRankModelSelected } from '@/app/components/datasets/common/check-re import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import PermissionSelector from '@/app/components/datasets/settings/permission-selector' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' -import { - useModelList, - useModelListAndDefaultModelAndCurrentProviderAndModel, -} from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { fetchMembers } from '@/service/common' import type { Member } from '@/models/common' import { IndexingType } from '@/app/components/datasets/create/step-two' import { useDocLink } from '@/context/i18n' +import { checkShowMultiModalTip } from '@/app/components/datasets/settings/utils' type SettingsModalProps = { currentDataset: DataSet @@ -54,10 +52,8 @@ const SettingsModal: FC = ({ onCancel, onSave, }) => { - const { data: embeddingsModelList } = useModelList(ModelTypeEnum.textEmbedding) - const { - modelList: rerankModelList, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) + const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank) const { t } = useTranslation() const docLink = useDocLink() const { notify } = useToastContext() @@ -181,6 +177,23 @@ const SettingsModal: FC = ({ getMembers() }) + const showMultiModalTip = useMemo(() => { + return checkShowMultiModalTip({ + embeddingModel: { + provider: localeCurrentDataset.embedding_model_provider, + model: localeCurrentDataset.embedding_model, + }, + rerankingEnable: retrievalConfig.reranking_enable, + rerankModel: { + rerankingProviderName: retrievalConfig.reranking_model.reranking_provider_name, + rerankingModelName: retrievalConfig.reranking_model.reranking_model_name, + }, + indexMethod, + embeddingModelList, + rerankModelList, + }) + }, [localeCurrentDataset.embedding_model, localeCurrentDataset.embedding_model_provider, retrievalConfig.reranking_enable, retrievalConfig.reranking_model, indexMethod, embeddingModelList, rerankModelList]) + return (
= ({ provider: localeCurrentDataset.embedding_model_provider, model: localeCurrentDataset.embedding_model, }} - modelList={embeddingsModelList} + modelList={embeddingModelList} />
@@ -344,6 +357,7 @@ const SettingsModal: FC = ({ ) : ( diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx new file mode 100644 index 0000000000..7607a21b07 --- /dev/null +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx @@ -0,0 +1,480 @@ +import '@testing-library/jest-dom' +import type { CSSProperties } from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import DebugWithMultipleModel from './index' +import type { DebugWithMultipleModelContextType } from './context' +import { APP_CHAT_WITH_MULTIPLE_MODEL } from '../types' +import type { ModelAndParameter } from '../types' +import type { Inputs, ModelConfig } from '@/models/debug' +import { DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config' +import type { FeatureStoreState } from '@/app/components/base/features/store' +import type { FileEntity } from '@/app/components/base/file-uploader/types' +import type { InputForm } from '@/app/components/base/chat/chat/type' +import { AppModeEnum, ModelModeType, type PromptVariable, Resolution, TransferMethod } from '@/types/app' + +type PromptVariableWithMeta = Omit & { + type: PromptVariable['type'] | 'api' + required?: boolean + hide?: boolean +} + +const mockUseDebugConfigurationContext = jest.fn() +const mockUseFeaturesSelector = jest.fn() +const mockUseEventEmitterContext = jest.fn() +const mockUseAppStoreSelector = jest.fn() +const mockEventEmitter = { emit: jest.fn() } +const mockSetShowAppConfigureFeaturesModal = jest.fn() +let capturedChatInputProps: MockChatInputAreaProps | null = null +let modelIdCounter = 0 +let featureState: FeatureStoreState + +type MockChatInputAreaProps = { + onSend?: (message: string, files?: FileEntity[]) => void + onFeatureBarClick?: (state: boolean) => void + showFeatureBar?: boolean + showFileUpload?: boolean + inputs?: Record + inputsForm?: InputForm[] + speechToTextConfig?: unknown + visionConfig?: unknown +} + +const mockFiles: FileEntity[] = [ + { + id: 'file-1', + name: 'file.txt', + size: 10, + type: 'text/plain', + progress: 100, + transferMethod: TransferMethod.remote_url, + supportFileType: 'text', + }, +] + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +jest.mock('@/context/debug-configuration', () => ({ + __esModule: true, + useDebugConfigurationContext: () => mockUseDebugConfigurationContext(), +})) + +jest.mock('@/app/components/base/features/hooks', () => ({ + __esModule: true, + useFeatures: (selector: (state: FeatureStoreState) => unknown) => mockUseFeaturesSelector(selector), +})) + +jest.mock('@/context/event-emitter', () => ({ + __esModule: true, + useEventEmitterContextContext: () => mockUseEventEmitterContext(), +})) + +jest.mock('@/app/components/app/store', () => ({ + __esModule: true, + useStore: (selector: (state: { setShowAppConfigureFeaturesModal: typeof mockSetShowAppConfigureFeaturesModal }) => unknown) => mockUseAppStoreSelector(selector), +})) + +jest.mock('./debug-item', () => ({ + __esModule: true, + default: ({ + modelAndParameter, + className, + style, + }: { + modelAndParameter: ModelAndParameter + className?: string + style?: CSSProperties + }) => ( +
+ DebugItem-{modelAndParameter.id} +
+ ), +})) + +jest.mock('@/app/components/base/chat/chat/chat-input-area', () => ({ + __esModule: true, + default: (props: MockChatInputAreaProps) => { + capturedChatInputProps = props + return ( +
+ + +
+ ) + }, +})) + +const createFeatureState = (): FeatureStoreState => ({ + features: { + speech2text: { enabled: true }, + file: { + image: { + enabled: true, + detail: Resolution.high, + number_limits: 2, + transfer_methods: [TransferMethod.remote_url], + }, + }, + }, + setFeatures: jest.fn(), + showFeaturesModal: false, + setShowFeaturesModal: jest.fn(), +}) + +const createModelConfig = (promptVariables: PromptVariableWithMeta[] = []): ModelConfig => ({ + provider: 'OPENAI', + model_id: 'gpt-4', + mode: ModelModeType.chat, + configs: { + prompt_template: '', + prompt_variables: promptVariables as unknown as PromptVariable[], + }, + chat_prompt_config: DEFAULT_CHAT_PROMPT_CONFIG, + completion_prompt_config: DEFAULT_COMPLETION_PROMPT_CONFIG, + opening_statement: '', + more_like_this: null, + suggested_questions: [], + suggested_questions_after_answer: null, + speech_to_text: null, + text_to_speech: null, + file_upload: null, + retriever_resource: null, + sensitive_word_avoidance: null, + annotation_reply: null, + external_data_tools: [], + system_parameters: { + audio_file_size_limit: 0, + file_size_limit: 0, + image_file_size_limit: 0, + video_file_size_limit: 0, + workflow_file_upload_limit: 0, + }, + dataSets: [], + agentConfig: DEFAULT_AGENT_SETTING, +}) + +type DebugConfiguration = { + mode: AppModeEnum + inputs: Inputs + modelConfig: ModelConfig +} + +const createDebugConfiguration = (overrides: Partial = {}): DebugConfiguration => ({ + mode: AppModeEnum.CHAT, + inputs: {}, + modelConfig: createModelConfig(), + ...overrides, +}) + +const createModelAndParameter = (overrides: Partial = {}): ModelAndParameter => ({ + id: `model-${++modelIdCounter}`, + model: 'gpt-3.5-turbo', + provider: 'openai', + parameters: {}, + ...overrides, +}) + +const createProps = (overrides: Partial = {}): DebugWithMultipleModelContextType => ({ + multipleModelConfigs: [createModelAndParameter()], + onMultipleModelConfigsChange: jest.fn(), + onDebugWithMultipleModelChange: jest.fn(), + ...overrides, +}) + +const renderComponent = (props?: Partial) => { + const mergedProps = createProps(props) + return render() +} + +describe('DebugWithMultipleModel', () => { + beforeEach(() => { + jest.clearAllMocks() + capturedChatInputProps = null + modelIdCounter = 0 + featureState = createFeatureState() + mockUseFeaturesSelector.mockImplementation(selector => selector(featureState)) + mockUseEventEmitterContext.mockReturnValue({ eventEmitter: mockEventEmitter }) + mockUseAppStoreSelector.mockImplementation(selector => selector({ setShowAppConfigureFeaturesModal: mockSetShowAppConfigureFeaturesModal })) + mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration()) + }) + + describe('chat input rendering', () => { + it('should render chat input in chat mode with transformed prompt variables and feature handler', () => { + // Arrange + const promptVariables: PromptVariableWithMeta[] = [ + { key: 'city', name: 'City', type: 'string', required: true }, + { key: 'audience', name: 'Audience', type: 'number' }, + { key: 'hidden', name: 'Hidden', type: 'select', hide: true }, + { key: 'api-only', name: 'API Only', type: 'api' }, + ] + const debugConfiguration = createDebugConfiguration({ + inputs: { audience: 'engineers' }, + modelConfig: createModelConfig(promptVariables), + }) + mockUseDebugConfigurationContext.mockReturnValue(debugConfiguration) + + // Act + renderComponent() + fireEvent.click(screen.getByRole('button', { name: /feature/i })) + + // Assert + expect(screen.getByTestId('chat-input-area')).toBeInTheDocument() + expect(capturedChatInputProps?.inputs).toEqual({ audience: 'engineers' }) + expect(capturedChatInputProps?.inputsForm).toEqual([ + expect.objectContaining({ label: 'City', variable: 'city', hide: false, required: true }), + expect.objectContaining({ label: 'Audience', variable: 'audience', hide: false, required: false }), + expect.objectContaining({ label: 'Hidden', variable: 'hidden', hide: true, required: false }), + ]) + expect(capturedChatInputProps?.showFeatureBar).toBe(true) + expect(capturedChatInputProps?.showFileUpload).toBe(false) + expect(capturedChatInputProps?.speechToTextConfig).toEqual(featureState.features.speech2text) + expect(capturedChatInputProps?.visionConfig).toEqual(featureState.features.file) + expect(mockSetShowAppConfigureFeaturesModal).toHaveBeenCalledWith(true) + }) + + it('should render chat input in agent chat mode', () => { + // Arrange + mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({ + mode: AppModeEnum.AGENT_CHAT, + })) + + // Act + renderComponent() + + // Assert + expect(screen.getByTestId('chat-input-area')).toBeInTheDocument() + }) + + it('should hide chat input when not in chat mode', () => { + // Arrange + mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({ + mode: AppModeEnum.COMPLETION, + })) + const multipleModelConfigs = [createModelAndParameter()] + + // Act + renderComponent({ multipleModelConfigs }) + + // Assert + expect(screen.queryByTestId('chat-input-area')).not.toBeInTheDocument() + expect(screen.getAllByTestId('debug-item')).toHaveLength(1) + }) + }) + + describe('sending flow', () => { + it('should emit chat event when allowed to send', () => { + // Arrange + const checkCanSend = jest.fn(() => true) + const multipleModelConfigs = [createModelAndParameter(), createModelAndParameter()] + renderComponent({ multipleModelConfigs, checkCanSend }) + + // Act + fireEvent.click(screen.getByRole('button', { name: /send/i })) + + // Assert + expect(checkCanSend).toHaveBeenCalled() + expect(mockEventEmitter.emit).toHaveBeenCalledWith({ + type: APP_CHAT_WITH_MULTIPLE_MODEL, + payload: { + message: 'test message', + files: mockFiles, + }, + }) + }) + + it('should emit when no checkCanSend is provided', () => { + renderComponent() + + fireEvent.click(screen.getByRole('button', { name: /send/i })) + + expect(mockEventEmitter.emit).toHaveBeenCalledWith({ + type: APP_CHAT_WITH_MULTIPLE_MODEL, + payload: { + message: 'test message', + files: mockFiles, + }, + }) + }) + + it('should block sending when checkCanSend returns false', () => { + // Arrange + const checkCanSend = jest.fn(() => false) + renderComponent({ checkCanSend }) + + // Act + fireEvent.click(screen.getByRole('button', { name: /send/i })) + + // Assert + expect(checkCanSend).toHaveBeenCalled() + expect(mockEventEmitter.emit).not.toHaveBeenCalled() + }) + + it('should tolerate missing event emitter without throwing', () => { + mockUseEventEmitterContext.mockReturnValue({ eventEmitter: null }) + renderComponent() + + expect(() => fireEvent.click(screen.getByRole('button', { name: /send/i }))).not.toThrow() + expect(mockEventEmitter.emit).not.toHaveBeenCalled() + }) + }) + + describe('layout sizing and positioning', () => { + const expectItemLayout = ( + element: HTMLElement, + expectation: { + width?: string + height?: string + transform: string + classes?: string[] + }, + ) => { + if (expectation.width !== undefined) + expect(element.style.width).toBe(expectation.width) + else + expect(element.style.width).toBe('') + + if (expectation.height !== undefined) + expect(element.style.height).toBe(expectation.height) + else + expect(element.style.height).toBe('') + + expect(element.style.transform).toBe(expectation.transform) + expectation.classes?.forEach(cls => expect(element).toHaveClass(cls)) + } + + it('should arrange items in two-column layout for two models', () => { + // Arrange + const multipleModelConfigs = [createModelAndParameter(), createModelAndParameter()] + + // Act + renderComponent({ multipleModelConfigs }) + const items = screen.getAllByTestId('debug-item') + + // Assert + expect(items).toHaveLength(2) + expectItemLayout(items[0], { + width: 'calc(50% - 4px - 24px)', + height: '100%', + transform: 'translateX(0) translateY(0)', + classes: ['mr-2'], + }) + expectItemLayout(items[1], { + width: 'calc(50% - 4px - 24px)', + height: '100%', + transform: 'translateX(calc(100% + 8px)) translateY(0)', + classes: [], + }) + }) + + it('should arrange items in thirds for three models', () => { + // Arrange + const multipleModelConfigs = [createModelAndParameter(), createModelAndParameter(), createModelAndParameter()] + + // Act + renderComponent({ multipleModelConfigs }) + const items = screen.getAllByTestId('debug-item') + + // Assert + expect(items).toHaveLength(3) + expectItemLayout(items[0], { + width: 'calc(33.3% - 5.33px - 16px)', + height: '100%', + transform: 'translateX(0) translateY(0)', + classes: ['mr-2'], + }) + expectItemLayout(items[1], { + width: 'calc(33.3% - 5.33px - 16px)', + height: '100%', + transform: 'translateX(calc(100% + 8px)) translateY(0)', + classes: ['mr-2'], + }) + expectItemLayout(items[2], { + width: 'calc(33.3% - 5.33px - 16px)', + height: '100%', + transform: 'translateX(calc(200% + 16px)) translateY(0)', + classes: [], + }) + }) + + it('should position items on a grid for four models', () => { + // Arrange + const multipleModelConfigs = [ + createModelAndParameter(), + createModelAndParameter(), + createModelAndParameter(), + createModelAndParameter(), + ] + + // Act + renderComponent({ multipleModelConfigs }) + const items = screen.getAllByTestId('debug-item') + + // Assert + expect(items).toHaveLength(4) + expectItemLayout(items[0], { + width: 'calc(50% - 4px - 24px)', + height: 'calc(50% - 4px)', + transform: 'translateX(0) translateY(0)', + classes: ['mr-2', 'mb-2'], + }) + expectItemLayout(items[1], { + width: 'calc(50% - 4px - 24px)', + height: 'calc(50% - 4px)', + transform: 'translateX(calc(100% + 8px)) translateY(0)', + classes: ['mb-2'], + }) + expectItemLayout(items[2], { + width: 'calc(50% - 4px - 24px)', + height: 'calc(50% - 4px)', + transform: 'translateX(0) translateY(calc(100% + 8px))', + classes: ['mr-2'], + }) + expectItemLayout(items[3], { + width: 'calc(50% - 4px - 24px)', + height: 'calc(50% - 4px)', + transform: 'translateX(calc(100% + 8px)) translateY(calc(100% + 8px))', + classes: [], + }) + }) + + it('should fall back to single column layout when only one model is provided', () => { + // Arrange + const multipleModelConfigs = [createModelAndParameter()] + + // Act + renderComponent({ multipleModelConfigs }) + const item = screen.getByTestId('debug-item') + + // Assert + expectItemLayout(item, { + transform: 'translateX(0) translateY(0)', + classes: [], + }) + }) + + it('should set scroll area height for chat modes', () => { + const { container } = renderComponent() + const scrollArea = container.querySelector('.relative.mb-3.grow.overflow-auto.px-6') as HTMLElement + expect(scrollArea).toBeInTheDocument() + expect(scrollArea.style.height).toBe('calc(100% - 60px)') + }) + + it('should set full height when chat input is hidden', () => { + mockUseDebugConfigurationContext.mockReturnValue(createDebugConfiguration({ + mode: AppModeEnum.COMPLETION, + })) + + const { container } = renderComponent() + const scrollArea = container.querySelector('.relative.mb-3.grow.overflow-auto.px-6') as HTMLElement + expect(scrollArea.style.height).toBe('100%') + }) + }) +}) diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index afe640278e..2537062e13 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -307,7 +307,7 @@ const Configuration: FC = () => { const oldRetrievalConfig = { top_k, score_threshold, - reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { + reranking_model: (reranking_model?.reranking_provider_name && reranking_model?.reranking_model_name) ? { provider: reranking_model.reranking_provider_name, model: reranking_model.reranking_model_name, } : undefined, diff --git a/web/app/components/app/create-app-dialog/app-list/index.tsx b/web/app/components/app/create-app-dialog/app-list/index.tsx index 8b19f43034..51b6874d52 100644 --- a/web/app/components/app/create-app-dialog/app-list/index.tsx +++ b/web/app/components/app/create-app-dialog/app-list/index.tsx @@ -28,6 +28,7 @@ import Input from '@/app/components/base/input' import { AppModeEnum } from '@/types/app' import { DSLImportMode } from '@/models/app' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' +import { trackEvent } from '@/app/components/base/amplitude' type AppsProps = { onSuccess?: () => void @@ -141,6 +142,15 @@ const Apps = ({ icon_background, description, }) + + // Track app creation from template + trackEvent('create_app_with_template', { + app_mode: mode, + template_id: currApp?.app.id, + template_name: currApp?.app.name, + description, + }) + setIsShowCreateModal(false) Toast.notify({ type: 'success', diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 10fc099f9f..a449ec8ef2 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -30,6 +30,7 @@ import { getRedirection } from '@/utils/app-redirection' import FullScreenModal from '@/app/components/base/fullscreen-modal' import useTheme from '@/hooks/use-theme' import { useDocLink } from '@/context/i18n' +import { trackEvent } from '@/app/components/base/amplitude' type CreateAppProps = { onSuccess: () => void @@ -82,6 +83,13 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, mode: appMode, }) + + // Track app creation success + trackEvent('create_app', { + app_mode: appMode, + description, + }) + notify({ type: 'success', message: t('app.newApp.appCreated') }) onSuccess() onClose() diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index 0c137abb71..3564738dfd 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -28,6 +28,7 @@ import { getRedirection } from '@/utils/app-redirection' import cn from '@/utils/classnames' import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' import { noop } from 'lodash-es' +import { trackEvent } from '@/app/components/base/amplitude' type CreateFromDSLModalProps = { show: boolean @@ -112,6 +113,13 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS return const { id, status, app_id, app_mode, imported_dsl_version, current_dsl_version } = response if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) { + // Track app creation from DSL import + trackEvent('create_app_with_dsl', { + app_mode, + creation_method: currentTab === CreateFromDSLModalTab.FROM_FILE ? 'dsl_file' : 'dsl_url', + has_warnings: status === DSLImportStatus.COMPLETED_WITH_WARNINGS, + }) + if (onSuccess) onSuccess() if (onClose) diff --git a/web/app/components/app/log/index.tsx b/web/app/components/app/log/index.tsx index cedf2de74d..4fda71bece 100644 --- a/web/app/components/app/log/index.tsx +++ b/web/app/components/app/log/index.tsx @@ -27,24 +27,33 @@ export type QueryParam = { sort_by?: string } +const defaultQueryParams: QueryParam = { + period: '2', + annotation_status: 'all', + sort_by: '-created_at', +} + +const logsStateCache = new Map() + const Logs: FC = ({ appDetail }) => { const { t } = useTranslation() const router = useRouter() const pathname = usePathname() const searchParams = useSearchParams() - const [queryParams, setQueryParams] = useState({ - period: '2', - annotation_status: 'all', - sort_by: '-created_at', - }) const getPageFromParams = useCallback(() => { const pageParam = Number.parseInt(searchParams.get('page') || '1', 10) if (Number.isNaN(pageParam) || pageParam < 1) return 0 return pageParam - 1 }, [searchParams]) - const [currPage, setCurrPage] = React.useState(() => getPageFromParams()) - const [limit, setLimit] = React.useState(APP_PAGE_LIMIT) + const cachedState = logsStateCache.get(appDetail.id) + const [queryParams, setQueryParams] = useState(cachedState?.queryParams ?? defaultQueryParams) + const [currPage, setCurrPage] = React.useState(() => cachedState?.currPage ?? getPageFromParams()) + const [limit, setLimit] = React.useState(cachedState?.limit ?? APP_PAGE_LIMIT) const debouncedQueryParams = useDebounce(queryParams, { wait: 500 }) useEffect(() => { @@ -52,6 +61,14 @@ const Logs: FC = ({ appDetail }) => { setCurrPage(prev => (prev === pageFromParams ? prev : pageFromParams)) }, [getPageFromParams]) + useEffect(() => { + logsStateCache.set(appDetail.id, { + queryParams, + currPage, + limit, + }) + }, [appDetail.id, currPage, limit, queryParams]) + // Get the app type first const isChatMode = appDetail.mode !== AppModeEnum.COMPLETION @@ -85,6 +102,11 @@ const Logs: FC = ({ appDetail }) => { const total = isChatMode ? chatConversations?.total : completionConversations?.total + const handleQueryParamsChange = useCallback((next: QueryParam) => { + setCurrPage(0) + setQueryParams(next) + }, []) + const handlePageChange = useCallback((page: number) => { setCurrPage(page) const params = new URLSearchParams(searchParams.toString()) @@ -101,7 +123,7 @@ const Logs: FC = ({ appDetail }) => {

{t('appLog.description')}

- + {total === undefined ? : total > 0 diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index d21d35eeee..0ff375d815 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -816,9 +816,12 @@ const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: st const { notify } = useContext(ToastContext) const { t } = useTranslation() - const handleFeedback = async (mid: string, { rating }: FeedbackType): Promise => { + const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise => { try { - await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) + await updateLogMessageFeedbacks({ + url: `/apps/${appId}/feedbacks`, + body: { message_id: mid, rating, content: content ?? undefined }, + }) conversationDetailMutate() notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) return true @@ -861,9 +864,12 @@ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string } const { notify } = useContext(ToastContext) const { t } = useTranslation() - const handleFeedback = async (mid: string, { rating }: FeedbackType): Promise => { + const handleFeedback = async (mid: string, { rating, content }: FeedbackType): Promise => { try { - await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) + await updateLogMessageFeedbacks({ + url: `/apps/${appId}/feedbacks`, + body: { message_id: mid, rating, content: content ?? undefined }, + }) notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) return true } diff --git a/web/app/components/app/workflow-log/filter.tsx b/web/app/components/app/workflow-log/filter.tsx index 1ef1bd7a29..0c8d72c1be 100644 --- a/web/app/components/app/workflow-log/filter.tsx +++ b/web/app/components/app/workflow-log/filter.tsx @@ -8,6 +8,7 @@ import quarterOfYear from 'dayjs/plugin/quarterOfYear' import type { QueryParam } from './index' import Chip from '@/app/components/base/chip' import Input from '@/app/components/base/input' +import { trackEvent } from '@/app/components/base/amplitude/utils' dayjs.extend(quarterOfYear) const today = dayjs() @@ -37,6 +38,9 @@ const Filter: FC = ({ queryParams, setQueryParams }: IFilterProps) value={queryParams.status || 'all'} onSelect={(item) => { setQueryParams({ ...queryParams, status: item.value as string }) + trackEvent('workflow_log_filter_status_selected', { + workflow_log_filter_status: item.value as string, + }) }} onClear={() => setQueryParams({ ...queryParams, status: 'all' })} items={[{ value: 'all', name: 'All' }, diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index 8356cfd31c..b8da0264e4 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -27,6 +27,7 @@ import { fetchWorkflowDraft } from '@/service/workflow' import { fetchInstalledAppList } from '@/service/explore' import { AppTypeIcon } from '@/app/components/app/type-selector' import Tooltip from '@/app/components/base/tooltip' +import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { AccessMode } from '@/models/access-control' import { useGlobalPublicStore } from '@/context/global-public-context' import { formatTime } from '@/utils/time' @@ -64,6 +65,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { const { isCurrentWorkspaceEditor } = useAppContext() const { onPlanInfoChanged } = useProviderContext() const { push } = useRouter() + const openAsyncWindow = useAsyncWindowOpen() const [showEditModal, setShowEditModal] = useState(false) const [showDuplicateModal, setShowDuplicateModal] = useState(false) @@ -247,11 +249,16 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { props.onClick?.() e.preventDefault() try { - const { installed_apps }: any = await fetchInstalledAppList(app.id) || {} - if (installed_apps?.length > 0) - window.open(`${basePath}/explore/installed/${installed_apps[0].id}`, '_blank') - else + await openAsyncWindow(async () => { + const { installed_apps }: any = await fetchInstalledAppList(app.id) || {} + if (installed_apps?.length > 0) + return `${basePath}/explore/installed/${installed_apps[0].id}` throw new Error('No app found in Explore') + }, { + onError: (err) => { + Toast.notify({ type: 'error', message: `${err.message || err}` }) + }, + }) } catch (e: any) { Toast.notify({ type: 'error', message: `${e.message || e}` }) diff --git a/web/app/components/base/amplitude/AmplitudeProvider.tsx b/web/app/components/base/amplitude/AmplitudeProvider.tsx new file mode 100644 index 0000000000..424475aba7 --- /dev/null +++ b/web/app/components/base/amplitude/AmplitudeProvider.tsx @@ -0,0 +1,49 @@ +'use client' + +import type { FC } from 'react' +import React, { useEffect } from 'react' +import * as amplitude from '@amplitude/analytics-browser' +import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser' +import { AMPLITUDE_API_KEY, IS_CLOUD_EDITION } from '@/config' + +export type IAmplitudeProps = { + sessionReplaySampleRate?: number +} + +// Check if Amplitude should be enabled +export const isAmplitudeEnabled = () => { + return IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY +} + +const AmplitudeProvider: FC = ({ + sessionReplaySampleRate = 1, +}) => { + useEffect(() => { + // Only enable in Saas edition with valid API key + if (!isAmplitudeEnabled()) + return + + // Initialize Amplitude + amplitude.init(AMPLITUDE_API_KEY, { + defaultTracking: { + sessions: true, + pageViews: true, + formInteractions: true, + fileDownloads: true, + }, + // Enable debug logs in development environment + logLevel: amplitude.Types.LogLevel.Warn, + }) + + // Add Session Replay plugin + const sessionReplay = sessionReplayPlugin({ + sampleRate: sessionReplaySampleRate, + }) + amplitude.add(sessionReplay) + }, []) + + // This is a client component that renders nothing + return null +} + +export default React.memo(AmplitudeProvider) diff --git a/web/app/components/base/amplitude/index.ts b/web/app/components/base/amplitude/index.ts new file mode 100644 index 0000000000..acc792339e --- /dev/null +++ b/web/app/components/base/amplitude/index.ts @@ -0,0 +1,2 @@ +export { default, isAmplitudeEnabled } from './AmplitudeProvider' +export { resetUser, setUserId, setUserProperties, trackEvent } from './utils' diff --git a/web/app/components/base/amplitude/utils.ts b/web/app/components/base/amplitude/utils.ts new file mode 100644 index 0000000000..57b96243ec --- /dev/null +++ b/web/app/components/base/amplitude/utils.ts @@ -0,0 +1,46 @@ +import * as amplitude from '@amplitude/analytics-browser' +import { isAmplitudeEnabled } from './AmplitudeProvider' + +/** + * Track custom event + * @param eventName Event name + * @param eventProperties Event properties (optional) + */ +export const trackEvent = (eventName: string, eventProperties?: Record) => { + if (!isAmplitudeEnabled()) + return + amplitude.track(eventName, eventProperties) +} + +/** + * Set user ID + * @param userId User ID + */ +export const setUserId = (userId: string) => { + if (!isAmplitudeEnabled()) + return + amplitude.setUserId(userId) +} + +/** + * Set user properties + * @param properties User properties + */ +export const setUserProperties = (properties: Record) => { + if (!isAmplitudeEnabled()) + return + const identifyEvent = new amplitude.Identify() + Object.entries(properties).forEach(([key, value]) => { + identifyEvent.set(key, value) + }) + amplitude.identify(identifyEvent) +} + +/** + * Reset user (e.g., when user logs out) + */ +export const resetUser = () => { + if (!isAmplitudeEnabled()) + return + amplitude.reset() +} diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index 302fb9a3c7..94c80687ed 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -284,7 +284,6 @@ const ChatWrapper = () => { themeBuilder={themeBuilder} switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)} inputDisabled={inputDisabled} - isMobile={isMobile} sidebarCollapseState={sidebarCollapseState} questionIcon={ initUserVariables?.avatar_url diff --git a/web/app/components/base/chat/chat/answer/operation.tsx b/web/app/components/base/chat/chat/answer/operation.tsx index 4667bb8c3d..6c1bfd1bbc 100644 --- a/web/app/components/base/chat/chat/answer/operation.tsx +++ b/web/app/components/base/chat/chat/answer/operation.tsx @@ -11,7 +11,10 @@ import { RiThumbDownLine, RiThumbUpLine, } from '@remixicon/react' -import type { ChatItem } from '../../types' +import type { + ChatItem, + Feedback, +} from '../../types' import { useChatContext } from '../context' import copy from 'copy-to-clipboard' import Toast from '@/app/components/base/toast' @@ -22,6 +25,7 @@ import ActionButton, { ActionButtonState } from '@/app/components/base/action-bu import NewAudioButton from '@/app/components/base/new-audio-button' import Modal from '@/app/components/base/modal/modal' import Textarea from '@/app/components/base/textarea' +import Tooltip from '@/app/components/base/tooltip' import cn from '@/utils/classnames' type OperationProps = { @@ -67,8 +71,9 @@ const Operation: FC = ({ agent_thoughts, humanInputFormData, } = item - const [localFeedback, setLocalFeedback] = useState(config?.supportAnnotation ? adminFeedback : feedback) + const [userLocalFeedback, setUserLocalFeedback] = useState(feedback) const [adminLocalFeedback, setAdminLocalFeedback] = useState(adminFeedback) + const [feedbackTarget, setFeedbackTarget] = useState<'user' | 'admin'>('user') // Separate feedback types for display const userFeedback = feedback @@ -80,24 +85,68 @@ const Operation: FC = ({ return messageContent }, [agent_thoughts, messageContent]) - const handleFeedback = async (rating: 'like' | 'dislike' | null, content?: string) => { + const displayUserFeedback = userLocalFeedback ?? userFeedback + + const hasUserFeedback = !!displayUserFeedback?.rating + const hasAdminFeedback = !!adminLocalFeedback?.rating + + const shouldShowUserFeedbackBar = !isOpeningStatement && config?.supportFeedback && !!onFeedback && !config?.supportAnnotation + const shouldShowAdminFeedbackBar = !isOpeningStatement && config?.supportFeedback && !!onFeedback && !!config?.supportAnnotation + + const userFeedbackLabel = t('appLog.table.header.userRate') || 'User feedback' + const adminFeedbackLabel = t('appLog.table.header.adminRate') || 'Admin feedback' + const feedbackTooltipClassName = 'max-w-[260px]' + + const buildFeedbackTooltip = (feedbackData?: Feedback | null, label = userFeedbackLabel) => { + if (!feedbackData?.rating) + return label + + const ratingLabel = feedbackData.rating === 'like' + ? (t('appLog.detail.operation.like') || 'like') + : (t('appLog.detail.operation.dislike') || 'dislike') + const feedbackText = feedbackData.content?.trim() + + if (feedbackText) + return `${label}: ${ratingLabel} - ${feedbackText}` + + return `${label}: ${ratingLabel}` + } + + const handleFeedback = async (rating: 'like' | 'dislike' | null, content?: string, target: 'user' | 'admin' = 'user') => { if (!config?.supportFeedback || !onFeedback) return await onFeedback?.(id, { rating, content }) - setLocalFeedback({ rating }) - // Update admin feedback state separately if annotation is supported - if (config?.supportAnnotation) - setAdminLocalFeedback(rating ? { rating } : undefined) + const nextFeedback = rating === null ? { rating: null } : { rating, content } + + if (target === 'admin') + setAdminLocalFeedback(nextFeedback) + else + setUserLocalFeedback(nextFeedback) } - const handleThumbsDown = () => { + const handleLikeClick = (target: 'user' | 'admin') => { + const currentRating = target === 'admin' ? adminLocalFeedback?.rating : displayUserFeedback?.rating + if (currentRating === 'like') { + handleFeedback(null, undefined, target) + return + } + handleFeedback('like', undefined, target) + } + + const handleDislikeClick = (target: 'user' | 'admin') => { + const currentRating = target === 'admin' ? adminLocalFeedback?.rating : displayUserFeedback?.rating + if (currentRating === 'dislike') { + handleFeedback(null, undefined, target) + return + } + setFeedbackTarget(target) setIsShowFeedbackModal(true) } const handleFeedbackSubmit = async () => { - await handleFeedback('dislike', feedbackContent) + await handleFeedback('dislike', feedbackContent, feedbackTarget) setFeedbackContent('') setIsShowFeedbackModal(false) } @@ -117,12 +166,13 @@ const Operation: FC = ({ width += 26 if (!isOpeningStatement && config?.supportAnnotation && config?.annotation_reply?.enabled) width += 26 - if (config?.supportFeedback && !localFeedback?.rating && onFeedback && !isOpeningStatement) - width += 60 + 8 - if (config?.supportFeedback && localFeedback?.rating && onFeedback && !isOpeningStatement) - width += 28 + 8 + if (shouldShowUserFeedbackBar) + width += hasUserFeedback ? 28 + 8 : 60 + 8 + if (shouldShowAdminFeedbackBar) + width += (hasAdminFeedback ? 28 : 60) + 8 + (hasUserFeedback ? 28 : 0) + return width - }, [isOpeningStatement, showPromptLog, config?.text_to_speech?.enabled, config?.supportAnnotation, config?.annotation_reply?.enabled, config?.supportFeedback, localFeedback?.rating, onFeedback]) + }, [config?.annotation_reply?.enabled, config?.supportAnnotation, config?.text_to_speech?.enabled, hasAdminFeedback, hasUserFeedback, isOpeningStatement, shouldShowAdminFeedbackBar, shouldShowUserFeedbackBar, showPromptLog]) const positionRight = useMemo(() => operationWidth < maxSize, [operationWidth, maxSize]) @@ -137,6 +187,110 @@ const Operation: FC = ({ )} style={(!hasWorkflowProcess && positionRight) ? { left: contentWidth + 8 } : {}} > + {shouldShowUserFeedbackBar && ( +
+ {hasUserFeedback ? ( + + handleFeedback(null, undefined, 'user')} + > + {displayUserFeedback?.rating === 'like' + ? + : } + + + ) : ( + <> + handleLikeClick('user')} + > + + + handleDislikeClick('user')} + > + + + + )} +
+ )} + {shouldShowAdminFeedbackBar && ( +
+ {/* User Feedback Display */} + {displayUserFeedback?.rating && ( + + {displayUserFeedback.rating === 'like' ? ( + + + + ) : ( + + + + )} + + )} + + {/* Admin Feedback Controls */} + {displayUserFeedback?.rating &&
} + {hasAdminFeedback ? ( + + handleFeedback(null, undefined, 'admin')} + > + {adminLocalFeedback?.rating === 'like' + ? + : } + + + ) : ( + <> + + handleLikeClick('admin')} + > + + + + + handleDislikeClick('admin')} + > + + + + + )} +
+ )} {showPromptLog && !isOpeningStatement && (
@@ -177,69 +331,6 @@ const Operation: FC = ({ )}
)} - {!isOpeningStatement && config?.supportFeedback && !localFeedback?.rating && onFeedback && ( -
- {!localFeedback?.rating && ( - <> - handleFeedback('like')}> - - - - - - - )} -
- )} - {!isOpeningStatement && config?.supportFeedback && onFeedback && ( -
- {/* User Feedback Display */} - {userFeedback?.rating && ( -
- User - {userFeedback.rating === 'like' ? ( - - - - ) : ( - - - - )} -
- )} - - {/* Admin Feedback Controls */} - {config?.supportAnnotation && ( -
- {userFeedback?.rating &&
} - {!adminLocalFeedback?.rating ? ( - <> - handleFeedback('like')}> - - - - - - - ) : ( - <> - {adminLocalFeedback.rating === 'like' ? ( - handleFeedback(null)}> - - - ) : ( - handleFeedback(null)}> - - - )} - - )} -
- )} - -
- )}
void noSpacing?: boolean inputDisabled?: boolean - isMobile?: boolean sidebarCollapseState?: boolean hideAvatar?: boolean onHumanInputFormSubmit?: (formID: string, formData: any) => void @@ -113,7 +112,6 @@ const Chat: FC = ({ onFeatureBarClick, noSpacing, inputDisabled, - isMobile, sidebarCollapseState, hideAvatar, onHumanInputFormSubmit, @@ -331,7 +329,6 @@ const Chat: FC = ({ ) } diff --git a/web/app/components/base/chat/chat/try-to-ask.tsx b/web/app/components/base/chat/chat/try-to-ask.tsx index 7e3dcc95f9..665f7b3b13 100644 --- a/web/app/components/base/chat/chat/try-to-ask.tsx +++ b/web/app/components/base/chat/chat/try-to-ask.tsx @@ -4,28 +4,25 @@ import { useTranslation } from 'react-i18next' import type { OnSend } from '../types' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' -import cn from '@/utils/classnames' type TryToAskProps = { suggestedQuestions: string[] onSend: OnSend - isMobile?: boolean } const TryToAsk: FC = ({ suggestedQuestions, onSend, - isMobile, }) => { const { t } = useTranslation() return (
-
- +
+
{t('appDebug.feature.suggestedQuestionsAfterAnswer.tryToAsk')}
- {!isMobile && } +
-
+
{ suggestedQuestions.map((suggestQuestion, index) => ( )}
) diff --git a/web/app/components/billing/plan-upgrade-modal/index.spec.tsx b/web/app/components/billing/plan-upgrade-modal/index.spec.tsx new file mode 100644 index 0000000000..324043d439 --- /dev/null +++ b/web/app/components/billing/plan-upgrade-modal/index.spec.tsx @@ -0,0 +1,118 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import PlanUpgradeModal from './index' + +const mockSetShowPricingModal = jest.fn() + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +jest.mock('@/app/components/base/modal', () => { + const MockModal = ({ isShow, children }: { isShow: boolean; children: React.ReactNode }) => ( + isShow ?
{children}
: null + ) + return { + __esModule: true, + default: MockModal, + } +}) + +jest.mock('@/context/modal-context', () => ({ + useModalContext: () => ({ + setShowPricingModal: mockSetShowPricingModal, + }), +})) + +const baseProps = { + title: 'Upgrade Required', + description: 'You need to upgrade your plan.', + show: true, + onClose: jest.fn(), +} + +const renderComponent = (props: Partial> = {}) => { + const mergedProps = { ...baseProps, ...props } + return render() +} + +describe('PlanUpgradeModal', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + // Rendering and props-driven content + it('should render modal with provided content when visible', () => { + // Arrange + const extraInfoText = 'Additional upgrade details' + renderComponent({ + extraInfo:
{extraInfoText}
, + }) + + // Assert + expect(screen.getByText(baseProps.title)).toBeInTheDocument() + expect(screen.getByText(baseProps.description)).toBeInTheDocument() + expect(screen.getByText(extraInfoText)).toBeInTheDocument() + expect(screen.getByText('billing.triggerLimitModal.dismiss')).toBeInTheDocument() + expect(screen.getByText('billing.triggerLimitModal.upgrade')).toBeInTheDocument() + }) + + // Guard against rendering when modal is hidden + it('should not render content when show is false', () => { + // Act + renderComponent({ show: false }) + + // Assert + expect(screen.queryByText(baseProps.title)).not.toBeInTheDocument() + expect(screen.queryByText(baseProps.description)).not.toBeInTheDocument() + }) + + // User closes the modal from dismiss button + it('should call onClose when dismiss button is clicked', async () => { + // Arrange + const user = userEvent.setup() + const onClose = jest.fn() + renderComponent({ onClose }) + + // Act + await user.click(screen.getByText('billing.triggerLimitModal.dismiss')) + + // Assert + expect(onClose).toHaveBeenCalledTimes(1) + }) + + // Upgrade path uses provided callback over pricing modal + it('should call onUpgrade and onClose when upgrade button is clicked with onUpgrade provided', async () => { + // Arrange + const user = userEvent.setup() + const onClose = jest.fn() + const onUpgrade = jest.fn() + renderComponent({ onClose, onUpgrade }) + + // Act + await user.click(screen.getByText('billing.triggerLimitModal.upgrade')) + + // Assert + expect(onClose).toHaveBeenCalledTimes(1) + expect(onUpgrade).toHaveBeenCalledTimes(1) + expect(mockSetShowPricingModal).not.toHaveBeenCalled() + }) + + // Fallback upgrade path opens pricing modal when no onUpgrade is supplied + it('should open pricing modal when upgrade button is clicked without onUpgrade', async () => { + // Arrange + const user = userEvent.setup() + const onClose = jest.fn() + renderComponent({ onClose, onUpgrade: undefined }) + + // Act + await user.click(screen.getByText('billing.triggerLimitModal.upgrade')) + + // Assert + expect(onClose).toHaveBeenCalledTimes(1) + expect(mockSetShowPricingModal).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/billing/plan-upgrade-modal/index.tsx b/web/app/components/billing/plan-upgrade-modal/index.tsx new file mode 100644 index 0000000000..4f5d1ed3a6 --- /dev/null +++ b/web/app/components/billing/plan-upgrade-modal/index.tsx @@ -0,0 +1,87 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import Modal from '@/app/components/base/modal' +import Button from '@/app/components/base/button' +import UpgradeBtn from '@/app/components/billing/upgrade-btn' +import styles from './style.module.css' +import { SquareChecklist } from '../../base/icons/src/vender/other' +import { useModalContext } from '@/context/modal-context' + +type Props = { + Icon?: React.ComponentType> + title: string + description: string + extraInfo?: React.ReactNode + show: boolean + onClose: () => void + onUpgrade?: () => void +} + +const PlanUpgradeModal: FC = ({ + Icon = SquareChecklist, + title, + description, + extraInfo, + show, + onClose, + onUpgrade, +}) => { + const { t } = useTranslation() + const { setShowPricingModal } = useModalContext() + + const handleUpgrade = useCallback(() => { + onClose() + onUpgrade ? onUpgrade() : setShowPricingModal() + }, [onClose, onUpgrade, setShowPricingModal]) + + return ( + +
+
+
+
+ +
+
+
+ {title} +
+
+ {description} +
+
+ {extraInfo} +
+
+ +
+ + +
+ + ) +} + +export default React.memo(PlanUpgradeModal) diff --git a/web/app/components/billing/trigger-events-limit-modal/index.module.css b/web/app/components/billing/plan-upgrade-modal/style.module.css similarity index 92% rename from web/app/components/billing/trigger-events-limit-modal/index.module.css rename to web/app/components/billing/plan-upgrade-modal/style.module.css index e8e86719e6..50ad488388 100644 --- a/web/app/components/billing/trigger-events-limit-modal/index.module.css +++ b/web/app/components/billing/plan-upgrade-modal/style.module.css @@ -19,7 +19,6 @@ background: linear-gradient(180deg, var(--color-components-avatar-bg-mask-stop-0, rgba(255, 255, 255, 0.12)) 0%, var(--color-components-avatar-bg-mask-stop-100, rgba(255, 255, 255, 0.08)) 100%), var(--color-util-colors-blue-brand-blue-brand-500, #296dff); - box-shadow: 0 10px 20px color-mix(in srgb, var(--color-util-colors-blue-brand-blue-brand-500, #296dff) 35%, transparent); } .highlight { diff --git a/web/app/components/billing/plan/index.tsx b/web/app/components/billing/plan/index.tsx index b695302965..fa76ed7b4d 100644 --- a/web/app/components/billing/plan/index.tsx +++ b/web/app/components/billing/plan/index.tsx @@ -1,8 +1,8 @@ 'use client' import type { FC } from 'react' -import React from 'react' +import React, { useEffect } from 'react' import { useTranslation } from 'react-i18next' -import { useRouter } from 'next/navigation' +import { usePathname, useRouter } from 'next/navigation' import { RiBook2Line, RiFileEditLine, @@ -25,6 +25,8 @@ import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM } from '@/app/education-apply/con import { useEducationVerify } from '@/service/use-education' import { useModalContextSelector } from '@/context/modal-context' import { Enterprise, Professional, Sandbox, Team } from './assets' +import { Loading } from '../../base/icons/src/public/thought' +import { useUnmountedRef } from 'ahooks' type Props = { loc: string @@ -35,6 +37,7 @@ const PlanComp: FC = ({ }) => { const { t } = useTranslation() const router = useRouter() + const path = usePathname() const { userProfile } = useAppContext() const { plan, enableEducationPlan, allowRefreshEducationVerify, isEducationAccount } = useProviderContext() const isAboutToExpire = allowRefreshEducationVerify @@ -61,17 +64,24 @@ const PlanComp: FC = ({ })() const [showModal, setShowModal] = React.useState(false) - const { mutateAsync } = useEducationVerify() + const { mutateAsync, isPending } = useEducationVerify() const setShowAccountSettingModal = useModalContextSelector(s => s.setShowAccountSettingModal) + const unmountedRef = useUnmountedRef() const handleVerify = () => { + if (isPending) return mutateAsync().then((res) => { localStorage.removeItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM) + if (unmountedRef.current) return router.push(`/education-apply?token=${res.token}`) - setShowAccountSettingModal(null) }).catch(() => { setShowModal(true) }) } + useEffect(() => { + // setShowAccountSettingModal would prevent navigation + if (path.startsWith('/education-apply')) + setShowAccountSettingModal(null) + }, [path, setShowAccountSettingModal]) return (
@@ -96,9 +106,10 @@ const PlanComp: FC = ({
{enableEducationPlan && (!isEducationAccount || isAboutToExpire) && ( - )} {(plan.type as any) !== SelfHostedPlan.enterprise && ( diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx index 5d7b11c7e0..52c2883b81 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx @@ -8,7 +8,8 @@ import { ALL_PLANS } from '../../../config' import Toast from '../../../../base/toast' import { PlanRange } from '../../plan-switcher/plan-range-switcher' import { useAppContext } from '@/context/app-context' -import { fetchSubscriptionUrls } from '@/service/billing' +import { fetchBillingUrl, fetchSubscriptionUrls } from '@/service/billing' +import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import List from './list' import Button from './button' import { Professional, Sandbox, Team } from '../../assets' @@ -39,8 +40,10 @@ const CloudPlanItem: FC = ({ const planInfo = ALL_PLANS[plan] const isYear = planRange === PlanRange.yearly const isCurrent = plan === currentPlan - const isPlanDisabled = planInfo.level <= ALL_PLANS[currentPlan].level + const isCurrentPaidPlan = isCurrent && !isFreePlan + const isPlanDisabled = isCurrentPaidPlan ? false : planInfo.level <= ALL_PLANS[currentPlan].level const { isCurrentWorkspaceManager } = useAppContext() + const openAsyncWindow = useAsyncWindowOpen() const btnText = useMemo(() => { if (isCurrent) @@ -60,10 +63,6 @@ const CloudPlanItem: FC = ({ if (isPlanDisabled) return - if (isFreePlan) - return - - // Only workspace manager can buy plan if (!isCurrentWorkspaceManager) { Toast.notify({ type: 'error', @@ -74,6 +73,23 @@ const CloudPlanItem: FC = ({ } setLoading(true) try { + if (isCurrentPaidPlan) { + await openAsyncWindow(async () => { + const res = await fetchBillingUrl() + if (res.url) + return res.url + throw new Error('Failed to open billing page') + }, { + onError: (err) => { + Toast.notify({ type: 'error', message: err.message || String(err) }) + }, + }) + return + } + + if (isFreePlan) + return + const res = await fetchSubscriptionUrls(plan, isYear ? 'year' : 'month') // Adb Block additional tracking block the gtag, so we need to redirect directly window.location.href = res.url diff --git a/web/app/components/billing/trigger-events-limit-modal/index.tsx b/web/app/components/billing/trigger-events-limit-modal/index.tsx index c1065a7868..9176c3d542 100644 --- a/web/app/components/billing/trigger-events-limit-modal/index.tsx +++ b/web/app/components/billing/trigger-events-limit-modal/index.tsx @@ -2,27 +2,22 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import Modal from '@/app/components/base/modal' -import Button from '@/app/components/base/button' import { TriggerAll } from '@/app/components/base/icons/src/vender/workflow' import UsageInfo from '@/app/components/billing/usage-info' -import UpgradeBtn from '@/app/components/billing/upgrade-btn' -import type { Plan } from '@/app/components/billing/type' -import styles from './index.module.css' +import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal' type Props = { show: boolean - onDismiss: () => void + onClose: () => void onUpgrade: () => void usage: number total: number resetInDays?: number - planType: Plan } const TriggerEventsLimitModal: FC = ({ show, - onDismiss, + onClose, onUpgrade, usage, total, @@ -31,59 +26,25 @@ const TriggerEventsLimitModal: FC = ({ const { t } = useTranslation() return ( - -
-
>} + title={t('billing.triggerLimitModal.title')} + description={t('billing.triggerLimitModal.description')} + extraInfo={( + -
-
- -
-
-
- {t('billing.triggerLimitModal.title')} -
-
- {t('billing.triggerLimitModal.description')} -
-
- -
-
- -
- - -
- + )} + /> ) } diff --git a/web/app/components/billing/upgrade-btn/index.tsx b/web/app/components/billing/upgrade-btn/index.tsx index d576e07f3e..b70daeb2e6 100644 --- a/web/app/components/billing/upgrade-btn/index.tsx +++ b/web/app/components/billing/upgrade-btn/index.tsx @@ -11,7 +11,7 @@ type Props = { className?: string style?: CSSProperties isFull?: boolean - size?: 'md' | 'lg' + size?: 's' | 'm' | 'custom' isPlain?: boolean isShort?: boolean onClick?: () => void @@ -21,6 +21,7 @@ type Props = { const UpgradeBtn: FC = ({ className, + size = 'm', style, isPlain = false, isShort = false, @@ -62,7 +63,7 @@ const UpgradeBtn: FC = ({ return ( = ({ - avatar_url, + avatarUrl, name, size = 20, className = '', }) => { - const [showAvatar, setShowAvatar] = useState(!!avatar_url && avatar_url !== 'default') + const [showAvatar, setShowAvatar] = useState(!!avatarUrl && avatarUrl !== 'default') const firstLetter = useMemo(() => name.charAt(0).toUpperCase(), [name]) const bgColor = useMemo(() => ICON_BG_COLORS[firstLetter.charCodeAt(0) % ICON_BG_COLORS.length], [firstLetter]) @@ -29,17 +29,20 @@ export const CredentialIcon: React.FC = ({ setShowAvatar(false) }, []) - if (avatar_url && avatar_url !== 'default' && showAvatar) { + if (avatarUrl && avatarUrl !== 'default' && showAvatar) { return (
diff --git a/web/app/components/datasets/common/document-status-with-action/index-failed.tsx b/web/app/components/datasets/common/document-status-with-action/index-failed.tsx index 802e3d872f..4713d944e0 100644 --- a/web/app/components/datasets/common/document-status-with-action/index-failed.tsx +++ b/web/app/components/datasets/common/document-status-with-action/index-failed.tsx @@ -2,11 +2,11 @@ import type { FC } from 'react' import React, { useEffect, useReducer } from 'react' import { useTranslation } from 'react-i18next' -import useSWR from 'swr' import StatusWithAction from './status-with-action' -import { getErrorDocs, retryErrorDocs } from '@/service/datasets' +import { retryErrorDocs } from '@/service/datasets' import type { IndexingStatusResponse } from '@/models/datasets' import { noop } from 'lodash-es' +import { useDatasetErrorDocs } from '@/service/knowledge/use-dataset' type Props = { datasetId: string @@ -35,16 +35,19 @@ const indexStateReducer = (state: IIndexState, action: IAction) => { const RetryButton: FC = ({ datasetId }) => { const { t } = useTranslation() const [indexState, dispatch] = useReducer(indexStateReducer, { value: 'success' }) - const { data: errorDocs, isLoading } = useSWR({ datasetId }, getErrorDocs) + const { data: errorDocs, isLoading, refetch: refetchErrorDocs } = useDatasetErrorDocs(datasetId) const onRetryErrorDocs = async () => { dispatch({ type: 'retry' }) const document_ids = errorDocs?.data.map((doc: IndexingStatusResponse) => doc.id) || [] const res = await retryErrorDocs({ datasetId, document_ids }) - if (res.result === 'success') + if (res.result === 'success') { + refetchErrorDocs() dispatch({ type: 'success' }) - else + } + else { dispatch({ type: 'error' }) + } } useEffect(() => { diff --git a/web/app/components/datasets/common/image-list/index.tsx b/web/app/components/datasets/common/image-list/index.tsx new file mode 100644 index 0000000000..8b0cf62e4a --- /dev/null +++ b/web/app/components/datasets/common/image-list/index.tsx @@ -0,0 +1,88 @@ +import { useCallback, useMemo, useState } from 'react' +import type { FileEntity } from '@/app/components/base/file-thumb' +import FileThumb from '@/app/components/base/file-thumb' +import cn from '@/utils/classnames' +import More from './more' +import type { ImageInfo } from '../image-previewer' +import ImagePreviewer from '../image-previewer' + +type Image = { + name: string + mimeType: string + sourceUrl: string + size: number + extension: string +} + +type ImageListProps = { + images: Image[] + size: 'sm' | 'md' + limit?: number + className?: string +} + +const ImageList = ({ + images, + size, + limit = 9, + className, +}: ImageListProps) => { + const [showMore, setShowMore] = useState(false) + const [previewIndex, setPreviewIndex] = useState(0) + const [previewImages, setPreviewImages] = useState([]) + + const limitedImages = useMemo(() => { + return showMore ? images : images.slice(0, limit) + }, [images, limit, showMore]) + + const handleShowMore = useCallback(() => { + setShowMore(true) + }, []) + + const handleImageClick = useCallback((file: FileEntity) => { + const index = limitedImages.findIndex(image => image.sourceUrl === file.sourceUrl) + if (index === -1) return + setPreviewIndex(index) + setPreviewImages(limitedImages.map(image => ({ + url: image.sourceUrl, + name: image.name, + size: image.size, + }))) + }, [limitedImages]) + + const handleClosePreview = useCallback(() => { + setPreviewImages([]) + }, []) + + return ( + <> +
+ { + limitedImages.map(image => ( + + )) + } + {images.length > limit && !showMore && ( + + )} +
+ {previewImages.length > 0 && ( + + )} + + ) +} + +export default ImageList diff --git a/web/app/components/datasets/common/image-list/more.tsx b/web/app/components/datasets/common/image-list/more.tsx new file mode 100644 index 0000000000..6da85e6939 --- /dev/null +++ b/web/app/components/datasets/common/image-list/more.tsx @@ -0,0 +1,39 @@ +import React, { useCallback } from 'react' + +type MoreProps = { + count: number + onClick?: () => void +} + +const More = ({ count, onClick }: MoreProps) => { + const formatNumber = (num: number) => { + if (num === 0) + return '0' + if (num < 1000) + return num.toString() + if (num < 1000000) + return `${(num / 1000).toFixed(1)}k` + return `${(num / 1000000).toFixed(1)}M` + } + + const handleClick = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onClick?.() + }, [onClick]) + + return ( +
+
+
+ + {`+${formatNumber(count)}`} + +
+
+
+
+ ) +} + +export default React.memo(More) diff --git a/web/app/components/datasets/common/image-previewer/index.tsx b/web/app/components/datasets/common/image-previewer/index.tsx new file mode 100644 index 0000000000..14e48d65fc --- /dev/null +++ b/web/app/components/datasets/common/image-previewer/index.tsx @@ -0,0 +1,223 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import Button from '@/app/components/base/button' +import Loading from '@/app/components/base/loading' +import { formatFileSize } from '@/utils/format' +import { RiArrowLeftLine, RiArrowRightLine, RiCloseLine, RiRefreshLine } from '@remixicon/react' +import { createPortal } from 'react-dom' +import { useHotkeys } from 'react-hotkeys-hook' + +type CachedImage = { + blobUrl?: string + status: 'loading' | 'loaded' | 'error' + width: number + height: number +} + +const imageCache = new Map() + +export type ImageInfo = { + url: string + name: string + size: number +} + +type ImagePreviewerProps = { + images: ImageInfo[] + initialIndex?: number + onClose: () => void +} + +const ImagePreviewer = ({ + images, + initialIndex = 0, + onClose, +}: ImagePreviewerProps) => { + const [currentIndex, setCurrentIndex] = useState(initialIndex) + const [cachedImages, setCachedImages] = useState>(() => { + return images.reduce((acc, image) => { + acc[image.url] = { + status: 'loading', + width: 0, + height: 0, + } + return acc + }, {} as Record) + }) + const isMounted = useRef(false) + + const fetchImage = useCallback(async (image: ImageInfo) => { + const { url } = image + // Skip if already cached + if (imageCache.has(url)) return + + try { + const res = await fetch(url) + if (!res.ok) throw new Error(`Failed to load: ${url}`) + const blob = await res.blob() + const blobUrl = URL.createObjectURL(blob) + + const img = new Image() + img.src = blobUrl + img.onload = () => { + if (!isMounted.current) return + imageCache.set(url, { + blobUrl, + status: 'loaded', + width: img.naturalWidth, + height: img.naturalHeight, + }) + setCachedImages((prev) => { + return { + ...prev, + [url]: { + blobUrl, + status: 'loaded', + width: img.naturalWidth, + height: img.naturalHeight, + }, + } + }) + } + } + catch { + if (isMounted.current) { + setCachedImages((prev) => { + return { + ...prev, + [url]: { + status: 'error', + width: 0, + height: 0, + }, + } + }) + } + } + }, []) + + useEffect(() => { + isMounted.current = true + + images.forEach((image) => { + fetchImage(image) + }) + + return () => { + isMounted.current = false + // Cleanup released blob URLs not in current list + imageCache.forEach(({ blobUrl }, key) => { + if (blobUrl) + URL.revokeObjectURL(blobUrl) + imageCache.delete(key) + }) + } + }, []) + + const currentImage = useMemo(() => { + return images[currentIndex] + }, [images, currentIndex]) + + const prevImage = useCallback(() => { + if (currentIndex === 0) + return + setCurrentIndex(prevIndex => prevIndex - 1) + }, [currentIndex]) + + const nextImage = useCallback(() => { + if (currentIndex === images.length - 1) + return + setCurrentIndex(prevIndex => prevIndex + 1) + }, [currentIndex, images.length]) + + const retryImage = useCallback((image: ImageInfo) => { + setCachedImages((prev) => { + return { + ...prev, + [image.url]: { + ...prev[image.url], + status: 'loading', + }, + } + }) + fetchImage(image) + }, [fetchImage]) + + useHotkeys('esc', onClose) + useHotkeys('left', prevImage) + useHotkeys('right', nextImage) + + return createPortal( +
e.stopPropagation()} + tabIndex={-1} + > +
+ + + Esc + +
+ {cachedImages[currentImage.url].status === 'loading' && ( + + )} + {cachedImages[currentImage.url].status === 'error' && ( +
+ {`Failed to load image: ${currentImage.url}. Please try again.`} + +
+ )} + {cachedImages[currentImage.url].status === 'loaded' && ( +
+ {currentImage.name} +
+ {currentImage.name} + · + {`${cachedImages[currentImage.url].width} ×  ${cachedImages[currentImage.url].height}`} + · + {formatFileSize(currentImage.size)} +
+
+ )} + + +
, + document.body, + ) +} + +export default ImagePreviewer diff --git a/web/app/components/datasets/common/image-uploader/constants.ts b/web/app/components/datasets/common/image-uploader/constants.ts new file mode 100644 index 0000000000..671ed94fcf --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/constants.ts @@ -0,0 +1,7 @@ +export const ACCEPT_TYPES = ['jpg', 'jpeg', 'png', 'gif'] + +export const DEFAULT_IMAGE_FILE_SIZE_LIMIT = 2 + +export const DEFAULT_IMAGE_FILE_BATCH_LIMIT = 5 + +export const DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT = 10 diff --git a/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts b/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts new file mode 100644 index 0000000000..aefe48f0cd --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/hooks/use-upload.ts @@ -0,0 +1,273 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useFileUploadConfig } from '@/service/use-common' +import type { FileEntity, FileUploadConfig } from '../types' +import { getFileType, getFileUploadConfig, traverseFileEntry } from '../utils' +import Toast from '@/app/components/base/toast' +import { useTranslation } from 'react-i18next' +import { ACCEPT_TYPES } from '../constants' +import { useFileStore } from '../store' +import { produce } from 'immer' +import { fileUpload, getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' +import { v4 as uuid4 } from 'uuid' + +export const useUpload = () => { + const { t } = useTranslation() + const fileStore = useFileStore() + + const [dragging, setDragging] = useState(false) + const uploaderRef = useRef(null) + const dragRef = useRef(null) + const dropRef = useRef(null) + + const { data: fileUploadConfigResponse } = useFileUploadConfig() + + const fileUploadConfig: FileUploadConfig = useMemo(() => { + return getFileUploadConfig(fileUploadConfigResponse) + }, [fileUploadConfigResponse]) + + const handleDragEnter = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.target !== dragRef.current) + setDragging(true) + } + const handleDragOver = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + } + const handleDragLeave = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.target === dragRef.current) + setDragging(false) + } + + const checkFileType = useCallback((file: File) => { + const ext = getFileType(file) + return ACCEPT_TYPES.includes(ext.toLowerCase()) + }, []) + + const checkFileSize = useCallback((file: File) => { + const { size } = file + return size <= fileUploadConfig.imageFileSizeLimit * 1024 * 1024 + }, [fileUploadConfig]) + + const showErrorMessage = useCallback((type: 'type' | 'size') => { + if (type === 'type') + Toast.notify({ type: 'error', message: t('common.fileUploader.fileExtensionNotSupport') }) + else + Toast.notify({ type: 'error', message: t('dataset.imageUploader.fileSizeLimitExceeded', { size: fileUploadConfig.imageFileSizeLimit }) }) + }, [fileUploadConfig, t]) + + const getValidFiles = useCallback((files: File[]) => { + let validType = true + let validSize = true + const validFiles = files.filter((file) => { + if (!checkFileType(file)) { + validType = false + return false + } + if (!checkFileSize(file)) { + validSize = false + return false + } + return true + }) + if (!validType) + showErrorMessage('type') + else if (!validSize) + showErrorMessage('size') + + return validFiles + }, [checkFileType, checkFileSize, showErrorMessage]) + + const selectHandle = () => { + if (uploaderRef.current) + uploaderRef.current.click() + } + + const handleAddFile = useCallback((newFile: FileEntity) => { + const { + files, + setFiles, + } = fileStore.getState() + + const newFiles = produce(files, (draft) => { + draft.push(newFile) + }) + setFiles(newFiles) + }, [fileStore]) + + const handleUpdateFile = useCallback((newFile: FileEntity) => { + const { + files, + setFiles, + } = fileStore.getState() + + const newFiles = produce(files, (draft) => { + const index = draft.findIndex(file => file.id === newFile.id) + + if (index > -1) + draft[index] = newFile + }) + setFiles(newFiles) + }, [fileStore]) + + const handleRemoveFile = useCallback((fileId: string) => { + const { + files, + setFiles, + } = fileStore.getState() + + const newFiles = files.filter(file => file.id !== fileId) + setFiles(newFiles) + }, [fileStore]) + + const handleReUploadFile = useCallback((fileId: string) => { + const { + files, + setFiles, + } = fileStore.getState() + const index = files.findIndex(file => file.id === fileId) + + if (index > -1) { + const uploadingFile = files[index] + const newFiles = produce(files, (draft) => { + draft[index].progress = 0 + }) + setFiles(newFiles) + fileUpload({ + file: uploadingFile.originalFile!, + onProgressCallback: (progress) => { + handleUpdateFile({ ...uploadingFile, progress }) + }, + onSuccessCallback: (res) => { + handleUpdateFile({ ...uploadingFile, uploadedId: res.id, progress: 100 }) + }, + onErrorCallback: (error?: any) => { + const errorMessage = getFileUploadErrorMessage(error, t('common.fileUploader.uploadFromComputerUploadError'), t) + Toast.notify({ type: 'error', message: errorMessage }) + handleUpdateFile({ ...uploadingFile, progress: -1 }) + }, + }) + } + }, [fileStore, t, handleUpdateFile]) + + const handleLocalFileUpload = useCallback((file: File) => { + const reader = new FileReader() + const isImage = file.type.startsWith('image') + + reader.addEventListener( + 'load', + () => { + const uploadingFile = { + id: uuid4(), + name: file.name, + extension: getFileType(file), + mimeType: file.type, + size: file.size, + progress: 0, + originalFile: file, + base64Url: isImage ? reader.result as string : '', + } + handleAddFile(uploadingFile) + fileUpload({ + file: uploadingFile.originalFile, + onProgressCallback: (progress) => { + handleUpdateFile({ ...uploadingFile, progress }) + }, + onSuccessCallback: (res) => { + handleUpdateFile({ + ...uploadingFile, + extension: res.extension, + mimeType: res.mime_type, + size: res.size, + uploadedId: res.id, + progress: 100, + }) + }, + onErrorCallback: (error?: any) => { + const errorMessage = getFileUploadErrorMessage(error, t('common.fileUploader.uploadFromComputerUploadError'), t) + Toast.notify({ type: 'error', message: errorMessage }) + handleUpdateFile({ ...uploadingFile, progress: -1 }) + }, + }) + }, + false, + ) + reader.addEventListener( + 'error', + () => { + Toast.notify({ type: 'error', message: t('common.fileUploader.uploadFromComputerReadError') }) + }, + false, + ) + reader.readAsDataURL(file) + }, [t, handleAddFile, handleUpdateFile]) + + const handleFileUpload = useCallback((newFiles: File[]) => { + const { files } = fileStore.getState() + const { singleChunkAttachmentLimit } = fileUploadConfig + if (newFiles.length === 0) return + if (files.length + newFiles.length > singleChunkAttachmentLimit) { + Toast.notify({ + type: 'error', + message: t('datasetHitTesting.imageUploader.singleChunkAttachmentLimitTooltip', { limit: singleChunkAttachmentLimit }), + }) + return + } + for (const file of newFiles) + handleLocalFileUpload(file) + }, [fileUploadConfig, fileStore, t, handleLocalFileUpload]) + + const fileChangeHandle = useCallback((e: React.ChangeEvent) => { + const { imageFileBatchLimit } = fileUploadConfig + const files = Array.from(e.target.files ?? []).slice(0, imageFileBatchLimit) + const validFiles = getValidFiles(files) + handleFileUpload(validFiles) + }, [getValidFiles, handleFileUpload, fileUploadConfig]) + + const handleDrop = useCallback(async (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + setDragging(false) + if (!e.dataTransfer) return + const nested = await Promise.all( + Array.from(e.dataTransfer.items).map((it) => { + const entry = (it as any).webkitGetAsEntry?.() + if (entry) return traverseFileEntry(entry) + const f = it.getAsFile?.() + return f ? Promise.resolve([f]) : Promise.resolve([]) + }), + ) + const files = nested.flat().slice(0, fileUploadConfig.imageFileBatchLimit) + const validFiles = getValidFiles(files) + handleFileUpload(validFiles) + }, [fileUploadConfig, handleFileUpload, getValidFiles]) + + useEffect(() => { + dropRef.current?.addEventListener('dragenter', handleDragEnter) + dropRef.current?.addEventListener('dragover', handleDragOver) + dropRef.current?.addEventListener('dragleave', handleDragLeave) + dropRef.current?.addEventListener('drop', handleDrop) + return () => { + dropRef.current?.removeEventListener('dragenter', handleDragEnter) + dropRef.current?.removeEventListener('dragover', handleDragOver) + dropRef.current?.removeEventListener('dragleave', handleDragLeave) + dropRef.current?.removeEventListener('drop', handleDrop) + } + }, [handleDrop]) + + return { + dragging, + fileUploadConfig, + dragRef, + dropRef, + uploaderRef, + fileChangeHandle, + selectHandle, + handleRemoveFile, + handleReUploadFile, + handleLocalFileUpload, + } +} diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-input.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-input.tsx new file mode 100644 index 0000000000..3e15b92705 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-input.tsx @@ -0,0 +1,64 @@ +import React from 'react' +import cn from '@/utils/classnames' +import { RiUploadCloud2Line } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useUpload } from '../hooks/use-upload' +import { ACCEPT_TYPES } from '../constants' + +const ImageUploader = () => { + const { t } = useTranslation() + + const { + dragging, + fileUploadConfig, + dragRef, + dropRef, + uploaderRef, + fileChangeHandle, + selectHandle, + } = useUpload() + + return ( +
+ `.${ext}`).join(',')} + onChange={fileChangeHandle} + /> +
+
+ +
+ {t('dataset.imageUploader.button')} + + {t('dataset.imageUploader.browse')} + +
+
+
+ {t('dataset.imageUploader.tip', { + size: fileUploadConfig.imageFileSizeLimit, + supportTypes: ACCEPT_TYPES.join(', '), + batchCount: fileUploadConfig.imageFileBatchLimit, + })} +
+ {dragging &&
} +
+
+ ) +} + +export default React.memo(ImageUploader) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-item.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-item.tsx new file mode 100644 index 0000000000..a5bfb65fa2 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/image-item.tsx @@ -0,0 +1,95 @@ +import { + memo, + useCallback, +} from 'react' +import { + RiCloseLine, +} from '@remixicon/react' +import FileImageRender from '@/app/components/base/file-uploader/file-image-render' +import type { FileEntity } from '../types' +import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' +import { ReplayLine } from '@/app/components/base/icons/src/vender/other' +import { fileIsUploaded } from '../utils' +import Button from '@/app/components/base/button' + +type ImageItemProps = { + file: FileEntity + showDeleteAction?: boolean + onRemove?: (fileId: string) => void + onReUpload?: (fileId: string) => void + onPreview?: (fileId: string) => void +} +const ImageItem = ({ + file, + showDeleteAction, + onRemove, + onReUpload, + onPreview, +}: ImageItemProps) => { + const { id, progress, base64Url, sourceUrl } = file + + const handlePreview = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onPreview?.(id) + }, [onPreview, id]) + + const handleRemove = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onRemove?.(id) + }, [onRemove, id]) + + const handleReUpload = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onReUpload?.(id) + }, [onReUpload, id]) + + return ( +
+ { + showDeleteAction && ( + + ) + } + + { + progress >= 0 && !fileIsUploaded(file) && ( +
+ +
+ ) + } + { + progress === -1 && ( +
+ +
+ ) + } +
+ ) +} + +export default memo(ImageItem) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/index.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/index.tsx new file mode 100644 index 0000000000..3efa3a19d7 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-chunk/index.tsx @@ -0,0 +1,94 @@ +import { + FileContextProvider, + useFileStoreWithSelector, +} from '../store' +import type { FileEntity } from '../types' +import FileItem from './image-item' +import { useUpload } from '../hooks/use-upload' +import ImageInput from './image-input' +import cn from '@/utils/classnames' +import { useCallback, useState } from 'react' +import type { ImageInfo } from '@/app/components/datasets/common/image-previewer' +import ImagePreviewer from '@/app/components/datasets/common/image-previewer' + +type ImageUploaderInChunkProps = { + disabled?: boolean + className?: string +} +const ImageUploaderInChunk = ({ + disabled, + className, +}: ImageUploaderInChunkProps) => { + const files = useFileStoreWithSelector(s => s.files) + const [previewIndex, setPreviewIndex] = useState(0) + const [previewImages, setPreviewImages] = useState([]) + + const handleImagePreview = useCallback((fileId: string) => { + const index = files.findIndex(item => item.id === fileId) + if (index === -1) return + setPreviewIndex(index) + setPreviewImages(files.map(item => ({ + url: item.base64Url || item.sourceUrl || '', + name: item.name, + size: item.size, + }))) + }, [files]) + + const handleClosePreview = useCallback(() => { + setPreviewImages([]) + }, []) + + const { + handleRemoveFile, + handleReUploadFile, + } = useUpload() + + return ( +
+ {!disabled && } +
+ { + files.map(file => ( + + )) + } +
+ {previewImages.length > 0 && ( + + )} +
+ ) +} + +export type ImageUploaderInChunkWrapperProps = { + value?: FileEntity[] + onChange: (files: FileEntity[]) => void +} & ImageUploaderInChunkProps + +const ImageUploaderInChunkWrapper = ({ + value, + onChange, + ...props +}: ImageUploaderInChunkWrapperProps) => { + return ( + + + + ) +} + +export default ImageUploaderInChunkWrapper diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-input.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-input.tsx new file mode 100644 index 0000000000..4f230e3957 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-input.tsx @@ -0,0 +1,64 @@ +import React from 'react' +import { useTranslation } from 'react-i18next' +import { useUpload } from '../hooks/use-upload' +import { ACCEPT_TYPES } from '../constants' +import { useFileStoreWithSelector } from '../store' +import { RiImageAddLine } from '@remixicon/react' +import Tooltip from '@/app/components/base/tooltip' + +const ImageUploader = () => { + const { t } = useTranslation() + const files = useFileStoreWithSelector(s => s.files) + + const { + fileUploadConfig, + uploaderRef, + fileChangeHandle, + selectHandle, + } = useUpload() + + return ( +
+ `.${ext}`).join(',')} + onChange={fileChangeHandle} + /> +
+ +
+
+ +
+ {files.length === 0 && ( + + {t('datasetHitTesting.imageUploader.tip', { + size: fileUploadConfig.imageFileSizeLimit, + batchCount: fileUploadConfig.imageFileBatchLimit, + })} + + )} +
+
+
+
+ ) +} + +export default React.memo(ImageUploader) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-item.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-item.tsx new file mode 100644 index 0000000000..a47356e560 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/image-item.tsx @@ -0,0 +1,95 @@ +import { + memo, + useCallback, +} from 'react' +import { + RiCloseLine, +} from '@remixicon/react' +import FileImageRender from '@/app/components/base/file-uploader/file-image-render' +import type { FileEntity } from '../types' +import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' +import { ReplayLine } from '@/app/components/base/icons/src/vender/other' +import { fileIsUploaded } from '../utils' +import Button from '@/app/components/base/button' + +type ImageItemProps = { + file: FileEntity + showDeleteAction?: boolean + onRemove?: (fileId: string) => void + onReUpload?: (fileId: string) => void + onPreview?: (fileId: string) => void +} +const ImageItem = ({ + file, + showDeleteAction, + onRemove, + onReUpload, + onPreview, +}: ImageItemProps) => { + const { id, progress, base64Url, sourceUrl } = file + + const handlePreview = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onPreview?.(id) + }, [onPreview, id]) + + const handleRemove = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onRemove?.(id) + }, [onRemove, id]) + + const handleReUpload = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + e.preventDefault() + onReUpload?.(id) + }, [onReUpload, id]) + + return ( +
+ { + showDeleteAction && ( + + ) + } + + { + progress >= 0 && !fileIsUploaded(file) && ( +
+ +
+ ) + } + { + progress === -1 && ( +
+ +
+ ) + } +
+ ) +} + +export default memo(ImageItem) diff --git a/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/index.tsx b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/index.tsx new file mode 100644 index 0000000000..2d04132842 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing/index.tsx @@ -0,0 +1,131 @@ +import { + useCallback, + useState, +} from 'react' +import { + FileContextProvider, +} from '../store' +import type { FileEntity } from '../types' +import { useUpload } from '../hooks/use-upload' +import ImageInput from './image-input' +import cn from '@/utils/classnames' +import { useTranslation } from 'react-i18next' +import { useFileStoreWithSelector } from '../store' +import ImageItem from './image-item' +import type { ImageInfo } from '@/app/components/datasets/common/image-previewer' +import ImagePreviewer from '@/app/components/datasets/common/image-previewer' + +type ImageUploaderInRetrievalTestingProps = { + textArea: React.ReactNode + actionButton: React.ReactNode + showUploader?: boolean + className?: string + actionAreaClassName?: string +} +const ImageUploaderInRetrievalTesting = ({ + textArea, + actionButton, + showUploader = true, + className, + actionAreaClassName, +}: ImageUploaderInRetrievalTestingProps) => { + const { t } = useTranslation() + const files = useFileStoreWithSelector(s => s.files) + const [previewIndex, setPreviewIndex] = useState(0) + const [previewImages, setPreviewImages] = useState([]) + const { + dragging, + dragRef, + dropRef, + handleRemoveFile, + handleReUploadFile, + } = useUpload() + + const handleImagePreview = useCallback((fileId: string) => { + const index = files.findIndex(item => item.id === fileId) + if (index === -1) return + setPreviewIndex(index) + setPreviewImages(files.map(item => ({ + url: item.base64Url || item.sourceUrl || '', + name: item.name, + size: item.size, + }))) + }, [files]) + + const handleClosePreview = useCallback(() => { + setPreviewImages([]) + }, []) + + return ( +
+ {dragging && ( +
+
{t('datasetHitTesting.imageUploader.dropZoneTip')}
+
+
+ )} + {textArea} + { + showUploader && !!files.length && ( +
+ { + files.map(file => ( + + )) + } +
+ ) + } +
+ {showUploader && } + {actionButton} +
+ {previewImages.length > 0 && ( + + )} +
+ ) +} + +export type ImageUploaderInRetrievalTestingWrapperProps = { + value?: FileEntity[] + onChange: (files: FileEntity[]) => void +} & ImageUploaderInRetrievalTestingProps + +const ImageUploaderInRetrievalTestingWrapper = ({ + value, + onChange, + ...props +}: ImageUploaderInRetrievalTestingWrapperProps) => { + return ( + + + + ) +} + +export default ImageUploaderInRetrievalTestingWrapper diff --git a/web/app/components/datasets/common/image-uploader/store.tsx b/web/app/components/datasets/common/image-uploader/store.tsx new file mode 100644 index 0000000000..e3c9e28a84 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/store.tsx @@ -0,0 +1,67 @@ +import { + createContext, + useContext, + useRef, +} from 'react' +import { + create, + useStore, +} from 'zustand' +import type { + FileEntity, +} from './types' + +type Shape = { + files: FileEntity[] + setFiles: (files: FileEntity[]) => void +} + +export const createFileStore = ( + value: FileEntity[] = [], + onChange?: (files: FileEntity[]) => void, +) => { + return create(set => ({ + files: value ? [...value] : [], + setFiles: (files) => { + set({ files }) + onChange?.(files) + }, + })) +} + +type FileStore = ReturnType +export const FileContext = createContext(null) + +export function useFileStoreWithSelector(selector: (state: Shape) => T): T { + const store = useContext(FileContext) + if (!store) + throw new Error('Missing FileContext.Provider in the tree') + + return useStore(store, selector) +} + +export const useFileStore = () => { + return useContext(FileContext)! +} + +type FileProviderProps = { + children: React.ReactNode + value?: FileEntity[] + onChange?: (files: FileEntity[]) => void +} +export const FileContextProvider = ({ + children, + value, + onChange, +}: FileProviderProps) => { + const storeRef = useRef(undefined) + + if (!storeRef.current) + storeRef.current = createFileStore(value, onChange) + + return ( + + {children} + + ) +} diff --git a/web/app/components/datasets/common/image-uploader/types.ts b/web/app/components/datasets/common/image-uploader/types.ts new file mode 100644 index 0000000000..e918f2b41e --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/types.ts @@ -0,0 +1,18 @@ +export type FileEntity = { + id: string + name: string + size: number + extension: string + mimeType: string + progress: number // -1: error, 0 ~ 99: uploading, 100: uploaded + originalFile?: File // used for re-uploading + uploadedId?: string // for uploaded image id + sourceUrl?: string // for uploaded image + base64Url?: string // for image preview during uploading +} + +export type FileUploadConfig = { + imageFileSizeLimit: number + imageFileBatchLimit: number + singleChunkAttachmentLimit: number +} diff --git a/web/app/components/datasets/common/image-uploader/utils.ts b/web/app/components/datasets/common/image-uploader/utils.ts new file mode 100644 index 0000000000..842b279a98 --- /dev/null +++ b/web/app/components/datasets/common/image-uploader/utils.ts @@ -0,0 +1,92 @@ +import type { FileUploadConfigResponse } from '@/models/common' +import type { FileEntity } from './types' +import { + DEFAULT_IMAGE_FILE_BATCH_LIMIT, + DEFAULT_IMAGE_FILE_SIZE_LIMIT, + DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT, +} from './constants' + +export const getFileType = (currentFile: File) => { + if (!currentFile) + return '' + + const arr = currentFile.name.split('.') + return arr[arr.length - 1] +} + +type FileWithPath = { + relativePath?: string +} & File + +export const traverseFileEntry = (entry: any, prefix = ''): Promise => { + return new Promise((resolve) => { + if (entry.isFile) { + entry.file((file: FileWithPath) => { + file.relativePath = `${prefix}${file.name}` + resolve([file]) + }) + } + else if (entry.isDirectory) { + const reader = entry.createReader() + const entries: any[] = [] + const read = () => { + reader.readEntries(async (results: FileSystemEntry[]) => { + if (!results.length) { + const files = await Promise.all( + entries.map(ent => + traverseFileEntry(ent, `${prefix}${entry.name}/`), + ), + ) + resolve(files.flat()) + } + else { + entries.push(...results) + read() + } + }) + } + read() + } + else { + resolve([]) + } + }) +} + +export const fileIsUploaded = (file: FileEntity) => { + if (file.uploadedId || file.progress === 100) + return true +} + +const getNumberValue = (value: number | string | undefined | null): number => { + if (value === undefined || value === null) + return 0 + if (typeof value === 'number') + return value + if (typeof value === 'string') + return Number(value) + return 0 +} + +export const getFileUploadConfig = (fileUploadConfigResponse: FileUploadConfigResponse | undefined) => { + if (!fileUploadConfigResponse) { + return { + imageFileSizeLimit: DEFAULT_IMAGE_FILE_SIZE_LIMIT, + imageFileBatchLimit: DEFAULT_IMAGE_FILE_BATCH_LIMIT, + singleChunkAttachmentLimit: DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT, + } + } + const { + image_file_batch_limit, + single_chunk_attachment_limit, + attachment_image_file_size_limit, + } = fileUploadConfigResponse + const imageFileSizeLimit = getNumberValue(attachment_image_file_size_limit) + const imageFileBatchLimit = getNumberValue(image_file_batch_limit) + const singleChunkAttachmentLimit = getNumberValue(single_chunk_attachment_limit) + return { + imageFileSizeLimit: imageFileSizeLimit > 0 ? imageFileSizeLimit : DEFAULT_IMAGE_FILE_SIZE_LIMIT, + imageFileBatchLimit: imageFileBatchLimit > 0 ? imageFileBatchLimit : DEFAULT_IMAGE_FILE_BATCH_LIMIT, + singleChunkAttachmentLimit: singleChunkAttachmentLimit > 0 ? singleChunkAttachmentLimit : DEFAULT_SINGLE_CHUNK_ATTACHMENT_LIMIT, + } +} diff --git a/web/app/components/datasets/common/retrieval-method-config/index.tsx b/web/app/components/datasets/common/retrieval-method-config/index.tsx index ed230c52ce..c0952ed4a4 100644 --- a/web/app/components/datasets/common/retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-config/index.tsx @@ -20,12 +20,14 @@ import { EffectColor } from '../../settings/chunk-structure/types' type Props = { disabled?: boolean value: RetrievalConfig + showMultiModalTip?: boolean onChange: (value: RetrievalConfig) => void } const RetrievalMethodConfig: FC = ({ disabled = false, value, + showMultiModalTip = false, onChange, }) => { const { t } = useTranslation() @@ -110,6 +112,7 @@ const RetrievalMethodConfig: FC = ({ type={RETRIEVE_METHOD.semantic} value={value} onChange={onChange} + showMultiModalTip={showMultiModalTip} /> )} @@ -132,6 +135,7 @@ const RetrievalMethodConfig: FC = ({ type={RETRIEVE_METHOD.fullText} value={value} onChange={onChange} + showMultiModalTip={showMultiModalTip} /> )} @@ -155,6 +159,7 @@ const RetrievalMethodConfig: FC = ({ type={RETRIEVE_METHOD.hybrid} value={value} onChange={onChange} + showMultiModalTip={showMultiModalTip} /> )} diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 0c28149d56..2b703cc44d 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -24,16 +24,19 @@ import { import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' import Toast from '@/app/components/base/toast' import RadioCard from '@/app/components/base/radio-card' +import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' type Props = { type: RETRIEVE_METHOD value: RetrievalConfig + showMultiModalTip?: boolean onChange: (value: RetrievalConfig) => void } const RetrievalParamConfig: FC = ({ type, value, + showMultiModalTip = false, onChange, }) => { const { t } = useTranslation() @@ -133,19 +136,32 @@ const RetrievalParamConfig: FC = ({
{ value.reranking_enable && ( - { - onChange({ - ...value, - reranking_model: { - reranking_provider_name: v.provider, - reranking_model_name: v.model, - }, - }) - }} - /> + <> + { + onChange({ + ...value, + reranking_model: { + reranking_provider_name: v.provider, + reranking_model_name: v.model, + }, + }) + }} + /> + {showMultiModalTip && ( +
+
+
+ +
+ + {t('datasetSettings.form.retrievalSetting.multiModalTip')} + +
+ )} + ) }
@@ -239,19 +255,32 @@ const RetrievalParamConfig: FC = ({ } { value.reranking_mode !== RerankingModeEnum.WeightedScore && ( - { - onChange({ - ...value, - reranking_model: { - reranking_provider_name: v.provider, - reranking_model_name: v.model, - }, - }) - }} - /> + <> + { + onChange({ + ...value, + reranking_model: { + reranking_provider_name: v.provider, + reranking_model_name: v.model, + }, + }) + }} + /> + {showMultiModalTip && ( +
+
+
+ +
+ + {t('datasetSettings.form.retrievalSetting.multiModalTip')} + +
+ )} + ) }
diff --git a/web/app/components/datasets/create/embedding-process/index.tsx b/web/app/components/datasets/create/embedding-process/index.tsx index 7b2eda1dcd..4e78eb2034 100644 --- a/web/app/components/datasets/create/embedding-process/index.tsx +++ b/web/app/components/datasets/create/embedding-process/index.tsx @@ -1,9 +1,7 @@ import type { FC } from 'react' import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' -import useSWR from 'swr' import { useRouter } from 'next/navigation' import { useTranslation } from 'react-i18next' -import { omit } from 'lodash-es' import { RiArrowRightLine, RiCheckboxCircleFill, @@ -25,7 +23,7 @@ import type { LegacyDataSourceInfo, ProcessRuleResponse, } from '@/models/datasets' -import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchProcessRule } from '@/service/datasets' +import { fetchIndexingStatusBatch as doFetchIndexingStatus } from '@/service/datasets' import { DataSourceType, ProcessMode } from '@/models/datasets' import NotionIcon from '@/app/components/base/notion-icon' import PriorityLabel from '@/app/components/billing/priority-label' @@ -40,6 +38,7 @@ import { useInvalidDocumentList } from '@/service/knowledge/use-document' import Divider from '@/app/components/base/divider' import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url' import Link from 'next/link' +import { useProcessRule } from '@/service/knowledge/use-dataset' type Props = { datasetId: string @@ -207,12 +206,7 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index }, []) // get rule - const { data: ruleDetail } = useSWR({ - action: 'fetchProcessRule', - params: { documentId: getFirstDocument.id }, - }, apiParams => fetchProcessRule(omit(apiParams, 'action')), { - revalidateOnFocus: false, - }) + const { data: ruleDetail } = useProcessRule(getFirstDocument?.id) const router = useRouter() const invalidDocumentList = useInvalidDocumentList() diff --git a/web/app/components/datasets/create/file-uploader/index.tsx b/web/app/components/datasets/create/file-uploader/index.tsx index 4aec0d4082..700a5f7680 100644 --- a/web/app/components/datasets/create/file-uploader/index.tsx +++ b/web/app/components/datasets/create/file-uploader/index.tsx @@ -2,7 +2,6 @@ import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import useSWR from 'swr' import { RiDeleteBinLine, RiUploadCloud2Line } from '@remixicon/react' import DocumentFileIcon from '../../common/document-file-icon' import cn from '@/utils/classnames' @@ -11,8 +10,7 @@ import { ToastContext } from '@/app/components/base/toast' import SimplePieChart from '@/app/components/base/simple-pie-chart' import { upload } from '@/service/base' -import { fetchFileUploadConfig } from '@/service/common' -import { fetchSupportFileTypes } from '@/service/datasets' +import { useFileSupportTypes, useFileUploadConfig } from '@/service/use-common' import I18n from '@/context/i18n' import { LanguagesSupported } from '@/i18n-config/language' import { IS_CE_EDITION } from '@/config' @@ -27,7 +25,7 @@ type IFileUploaderProps = { onFileUpdate: (fileItem: FileItem, progress: number, list: FileItem[]) => void onFileListUpdate?: (files: FileItem[]) => void onPreview: (file: File) => void - notSupportBatchUpload?: boolean + supportBatchUpload?: boolean } const FileUploader = ({ @@ -37,7 +35,7 @@ const FileUploader = ({ onFileUpdate, onFileListUpdate, onPreview, - notSupportBatchUpload, + supportBatchUpload = false, }: IFileUploaderProps) => { const { t } = useTranslation() const { notify } = useContext(ToastContext) @@ -46,10 +44,10 @@ const FileUploader = ({ const dropRef = useRef(null) const dragRef = useRef(null) const fileUploader = useRef(null) - const hideUpload = notSupportBatchUpload && fileList.length > 0 + const hideUpload = !supportBatchUpload && fileList.length > 0 - const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) - const { data: supportFileTypesResponse } = useSWR({ url: '/files/support-type' }, fetchSupportFileTypes) + const { data: fileUploadConfigResponse } = useFileUploadConfig() + const { data: supportFileTypesResponse } = useFileSupportTypes() const supportTypes = supportFileTypesResponse?.allowed_extensions || [] const supportTypesShowNames = (() => { const extensionMap: { [key: string]: string } = { @@ -68,11 +66,11 @@ const FileUploader = ({ .join(locale !== LanguagesSupported[1] ? ', ' : '、 ') })() const ACCEPTS = supportTypes.map((ext: string) => `.${ext}`) - const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? { - file_size_limit: 15, - batch_count_limit: 5, - file_upload_limit: 5, - }, [fileUploadConfigResponse]) + const fileUploadConfig = useMemo(() => ({ + file_size_limit: fileUploadConfigResponse?.file_size_limit ?? 15, + batch_count_limit: supportBatchUpload ? (fileUploadConfigResponse?.batch_count_limit ?? 5) : 1, + file_upload_limit: supportBatchUpload ? (fileUploadConfigResponse?.file_upload_limit ?? 5) : 1, + }), [fileUploadConfigResponse, supportBatchUpload]) const fileListRef = useRef([]) @@ -256,12 +254,12 @@ const FileUploader = ({ }), ) let files = nested.flat() - if (notSupportBatchUpload) files = files.slice(0, 1) + if (!supportBatchUpload) files = files.slice(0, 1) files = files.slice(0, fileUploadConfig.batch_count_limit) const valid = files.filter(isValid) initialUpload(valid) }, - [initialUpload, isValid, notSupportBatchUpload, traverseFileEntry, fileUploadConfig], + [initialUpload, isValid, supportBatchUpload, traverseFileEntry, fileUploadConfig], ) const selectHandle = () => { if (fileUploader.current) @@ -305,7 +303,7 @@ const FileUploader = ({ id="fileUploader" className="hidden" type="file" - multiple={!notSupportBatchUpload} + multiple={supportBatchUpload} accept={ACCEPTS.join(',')} onChange={fileChangeHandle} /> @@ -319,7 +317,7 @@ const FileUploader = ({ - {notSupportBatchUpload ? t('datasetCreation.stepOne.uploader.buttonSingleFile') : t('datasetCreation.stepOne.uploader.button')} + {supportBatchUpload ? t('datasetCreation.stepOne.uploader.button') : t('datasetCreation.stepOne.uploader.buttonSingleFile')} {supportTypes.length > 0 && ( )} @@ -328,7 +326,7 @@ const FileUploader = ({
{t('datasetCreation.stepOne.uploader.tip', { size: fileUploadConfig.file_size_limit, supportTypes: supportTypesShowNames, - batchCount: notSupportBatchUpload ? 1 : fileUploadConfig.batch_count_limit, + batchCount: fileUploadConfig.batch_count_limit, totalCount: fileUploadConfig.file_upload_limit, })}
{dragging &&
} diff --git a/web/app/components/datasets/create/step-one/index.tsx b/web/app/components/datasets/create/step-one/index.tsx index cab1637661..e70feb204c 100644 --- a/web/app/components/datasets/create/step-one/index.tsx +++ b/web/app/components/datasets/create/step-one/index.tsx @@ -22,6 +22,10 @@ import classNames from '@/utils/classnames' import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config' import NotionConnector from '@/app/components/base/notion-connector' import type { DataSourceAuth } from '@/app/components/header/account-setting/data-source-page-new/types' +import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal' +import { useBoolean } from 'ahooks' +import { Plan } from '@/app/components/billing/type' +import UpgradeCard from './upgrade-card' type IStepOneProps = { datasetId?: string @@ -52,7 +56,7 @@ const StepOne = ({ dataSourceTypeDisable, changeType, onSetting, - onStepChange, + onStepChange: doOnStepChange, files, updateFileList, updateFile, @@ -110,7 +114,33 @@ const StepOne = ({ const hasNotin = notionPages.length > 0 const isVectorSpaceFull = plan.usage.vectorSpace >= plan.total.vectorSpace const isShowVectorSpaceFull = (allFileLoaded || hasNotin) && isVectorSpaceFull && enableBilling - const notSupportBatchUpload = enableBilling && plan.type === 'sandbox' + const supportBatchUpload = !enableBilling || plan.type !== Plan.sandbox + const notSupportBatchUpload = !supportBatchUpload + + const [isShowPlanUpgradeModal, { + setTrue: showPlanUpgradeModal, + setFalse: hidePlanUpgradeModal, + }] = useBoolean(false) + const onStepChange = useCallback(() => { + if (notSupportBatchUpload) { + let isMultiple = false + if (dataSourceType === DataSourceType.FILE && files.length > 1) + isMultiple = true + + if (dataSourceType === DataSourceType.NOTION && notionPages.length > 1) + isMultiple = true + + if (dataSourceType === DataSourceType.WEB && websitePages.length > 1) + isMultiple = true + + if (isMultiple) { + showPlanUpgradeModal() + return + } + } + doOnStepChange() + }, [dataSourceType, doOnStepChange, files.length, notSupportBatchUpload, notionPages.length, showPlanUpgradeModal, websitePages.length]) + const nextDisabled = useMemo(() => { if (!files.length) return true @@ -229,7 +259,7 @@ const StepOne = ({ onFileListUpdate={updateFileList} onFileUpdate={updateFile} onPreview={updateCurrentFile} - notSupportBatchUpload={notSupportBatchUpload} + supportBatchUpload={supportBatchUpload} /> {isShowVectorSpaceFull && (
@@ -244,6 +274,14 @@ const StepOne = ({
+ { + enableBilling && plan.type === Plan.sandbox && files.length > 0 && ( +
+
+ +
+ ) + } )} {dataSourceType === DataSourceType.NOTION && ( @@ -330,6 +368,14 @@ const StepOne = ({ /> )} {currentWebsite && } + {isShowPlanUpgradeModal && ( + + )}
diff --git a/web/app/components/datasets/create/step-one/upgrade-card.tsx b/web/app/components/datasets/create/step-one/upgrade-card.tsx new file mode 100644 index 0000000000..af683d8ace --- /dev/null +++ b/web/app/components/datasets/create/step-one/upgrade-card.tsx @@ -0,0 +1,33 @@ +'use client' +import UpgradeBtn from '@/app/components/billing/upgrade-btn' +import { useModalContext } from '@/context/modal-context' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' + +const UpgradeCard: FC = () => { + const { t } = useTranslation() + const { setShowPricingModal } = useModalContext() + + const handleUpgrade = useCallback(() => { + setShowPricingModal() + }, [setShowPricingModal]) + + return ( +
+
+
{t('billing.upgrade.uploadMultipleFiles.title')}
+
{t('billing.upgrade.uploadMultipleFiles.description')}
+
+ +
+ ) +} +export default React.memo(UpgradeCard) diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 22d6837754..43be89c326 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC, PropsWithChildren } from 'react' -import React, { useCallback, useEffect, useState } from 'react' +import React, { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { @@ -63,6 +63,7 @@ import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/aler import { noop } from 'lodash-es' import { useDocLink } from '@/context/i18n' import { useInvalidDatasetList } from '@/service/knowledge/use-dataset' +import { checkShowMultiModalTip } from '../../settings/utils' const TextLabel: FC = (props) => { return @@ -495,12 +496,6 @@ const StepTwo = ({ setDefaultConfig(data.rules) setLimitMaxChunkLength(data.limits.indexing_max_segmentation_tokens_length) }, - onError(error) { - Toast.notify({ - type: 'error', - message: `${error}`, - }) - }, }) const getRulesFromDetail = () => { @@ -538,22 +533,8 @@ const StepTwo = ({ setSegmentationType(documentDetail.dataset_process_rule.mode) } - const createFirstDocumentMutation = useCreateFirstDocument({ - onError(error) { - Toast.notify({ - type: 'error', - message: `${error}`, - }) - }, - }) - const createDocumentMutation = useCreateDocument(datasetId!, { - onError(error) { - Toast.notify({ - type: 'error', - message: `${error}`, - }) - }, - }) + const createFirstDocumentMutation = useCreateFirstDocument() + const createDocumentMutation = useCreateDocument(datasetId!) const isCreating = createFirstDocumentMutation.isPending || createDocumentMutation.isPending const invalidDatasetList = useInvalidDatasetList() @@ -613,6 +594,20 @@ const StepTwo = ({ const isModelAndRetrievalConfigDisabled = !!datasetId && !!currentDataset?.data_source_type + const showMultiModalTip = useMemo(() => { + return checkShowMultiModalTip({ + embeddingModel, + rerankingEnable: retrievalConfig.reranking_enable, + rerankModel: { + rerankingProviderName: retrievalConfig.reranking_model.reranking_provider_name, + rerankingModelName: retrievalConfig.reranking_model.reranking_model_name, + }, + indexMethod: indexType, + embeddingModelList, + rerankModelList, + }) + }, [embeddingModel, retrievalConfig.reranking_enable, retrievalConfig.reranking_model, indexType, embeddingModelList, rerankModelList]) + return (
@@ -1012,6 +1007,7 @@ const StepTwo = ({ disabled={isModelAndRetrievalConfigDisabled} value={retrievalConfig} onChange={setRetrievalConfig} + showMultiModalTip={showMultiModalTip} /> ) : ( diff --git a/web/app/components/datasets/create/website/base/mock-crawl-result.ts b/web/app/components/datasets/create/website/base/mock-crawl-result.ts deleted file mode 100644 index 88c05d3d0a..0000000000 --- a/web/app/components/datasets/create/website/base/mock-crawl-result.ts +++ /dev/null @@ -1,24 +0,0 @@ -import type { CrawlResultItem } from '@/models/datasets' - -const result: CrawlResultItem[] = [ - { - title: 'Start the frontend Docker container separately', - content: 'Markdown 1', - description: 'Description 1', - source_url: 'https://example.com/1', - }, - { - title: 'Advanced Tool Integration', - content: 'Markdown 2', - description: 'Description 2', - source_url: 'https://example.com/2', - }, - { - title: 'Local Source Code Start | English | Dify', - content: 'Markdown 3', - description: 'Description 3', - source_url: 'https://example.com/3', - }, -] - -export default result diff --git a/web/app/components/datasets/create/website/firecrawl/index.tsx b/web/app/components/datasets/create/website/firecrawl/index.tsx index 51c2c7d505..17d8d6416e 100644 --- a/web/app/components/datasets/create/website/firecrawl/index.tsx +++ b/web/app/components/datasets/create/website/firecrawl/index.tsx @@ -166,10 +166,6 @@ const FireCrawl: FC = ({ setCrawlErrorMessage(errorMessage || t(`${I18N_PREFIX}.unknownError`)) } else { - data.data = data.data.map((item: any) => ({ - ...item, - content: item.markdown, - })) setCrawlResult(data) onCheckedCrawlResultChange(data.data || []) // default select the crawl result setCrawlErrorMessage('') diff --git a/web/app/components/datasets/create/website/jina-reader/index.tsx b/web/app/components/datasets/create/website/jina-reader/index.tsx index b6e6177af2..7257e8f7e6 100644 --- a/web/app/components/datasets/create/website/jina-reader/index.tsx +++ b/web/app/components/datasets/create/website/jina-reader/index.tsx @@ -157,7 +157,7 @@ const JinaReader: FC = ({ total: 1, data: [{ title, - content, + markdown: content, description, source_url: url, }], diff --git a/web/app/components/datasets/create/website/preview.tsx b/web/app/components/datasets/create/website/preview.tsx index d148c87196..f43dc83589 100644 --- a/web/app/components/datasets/create/website/preview.tsx +++ b/web/app/components/datasets/create/website/preview.tsx @@ -32,7 +32,7 @@ const WebsitePreview = ({
{payload.source_url}
-
{payload.content}
+
{payload.markdown}
) diff --git a/web/app/components/datasets/create/website/watercrawl/index.tsx b/web/app/components/datasets/create/website/watercrawl/index.tsx index 67a3e53feb..f1aee37e19 100644 --- a/web/app/components/datasets/create/website/watercrawl/index.tsx +++ b/web/app/components/datasets/create/website/watercrawl/index.tsx @@ -132,7 +132,7 @@ const WaterCrawl: FC = ({ }, } } - }, [crawlOptions.limit]) + }, [crawlOptions.limit, onCheckedCrawlResultChange]) const handleRun = useCallback(async (url: string) => { const { isValid, errorMsg } = checkValid(url) @@ -174,7 +174,7 @@ const WaterCrawl: FC = ({ finally { setStep(Step.finished) } - }, [checkValid, crawlOptions, onJobIdChange, t, waitForCrawlFinished]) + }, [checkValid, crawlOptions, onCheckedCrawlResultChange, onJobIdChange, t, waitForCrawlFinished]) return (
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx index 0de3879969..0e588e4e1d 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx @@ -10,14 +10,12 @@ import Trigger from './trigger' import List from './list' export type CredentialSelectorProps = { - pluginName: string currentCredentialId: string onCredentialChange: (credentialId: string) => void credentials: Array } const CredentialSelector = ({ - pluginName, currentCredentialId, onCredentialChange, credentials, @@ -50,7 +48,6 @@ const CredentialSelector = ({ @@ -58,7 +55,6 @@ const CredentialSelector = ({ diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx index ab8de51fb1..9c8368e299 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx @@ -2,22 +2,18 @@ import { CredentialIcon } from '@/app/components/datasets/common/credential-icon import type { DataSourceCredential } from '@/types/pipeline' import { RiCheckLine } from '@remixicon/react' import React, { useCallback } from 'react' -import { useTranslation } from 'react-i18next' type ItemProps = { credential: DataSourceCredential - pluginName: string isSelected: boolean onCredentialChange: (credentialId: string) => void } const Item = ({ credential, - pluginName, isSelected, onCredentialChange, }: ItemProps) => { - const { t } = useTranslation() const { avatar_url, name } = credential const handleCredentialChange = useCallback(() => { @@ -30,15 +26,12 @@ const Item = ({ onClick={handleCredentialChange} > - {t('datasetPipeline.credentialSelector.name', { - credentialName: name, - pluginName, - })} + {name} { isSelected && ( diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/list.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/list.tsx index b161a80309..cdcb2b5af5 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/list.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/list.tsx @@ -5,14 +5,12 @@ import Item from './item' type ListProps = { currentCredentialId: string credentials: Array - pluginName: string onCredentialChange: (credentialId: string) => void } const List = ({ currentCredentialId, credentials, - pluginName, onCredentialChange, }: ListProps) => { return ( @@ -24,7 +22,6 @@ const List = ({ diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/trigger.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/trigger.tsx index 88f47384f3..dc328ef87f 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/trigger.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/trigger.tsx @@ -1,23 +1,18 @@ import React from 'react' import type { DataSourceCredential } from '@/types/pipeline' -import { useTranslation } from 'react-i18next' import { RiArrowDownSLine } from '@remixicon/react' import cn from '@/utils/classnames' import { CredentialIcon } from '@/app/components/datasets/common/credential-icon' type TriggerProps = { currentCredential: DataSourceCredential | undefined - pluginName: string isOpen: boolean } const Trigger = ({ currentCredential, - pluginName, isOpen, }: TriggerProps) => { - const { t } = useTranslation() - const { avatar_url, name = '', @@ -31,16 +26,13 @@ const Trigger = ({ )} >
- {t('datasetPipeline.credentialSelector.name', { - credentialName: name, - pluginName, - })} + {name}
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx index ef8932ba24..b826e53d93 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx @@ -11,12 +11,14 @@ type HeaderProps = { docTitle: string docLink: string onClickConfiguration?: () => void + pluginName: string } & CredentialSelectorProps const Header = ({ docTitle, docLink, onClickConfiguration, + pluginName, ...rest }: HeaderProps) => { const { t } = useTranslation() @@ -29,7 +31,7 @@ const Header = ({ />
) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/index.tsx index f21f65904b..b313cadbc8 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/index.tsx @@ -11,8 +11,8 @@ type FileListProps = { fileList: OnlineDriveFile[] selectedFileIds: string[] keywords: string - isInPipeline: boolean isLoading: boolean + supportBatchUpload: boolean handleResetKeywords: () => void handleSelectFile: (file: OnlineDriveFile) => void handleOpenFolder: (file: OnlineDriveFile) => void @@ -25,8 +25,8 @@ const List = ({ handleResetKeywords, handleSelectFile, handleOpenFolder, - isInPipeline, isLoading, + supportBatchUpload, }: FileListProps) => { const anchorRef = useRef(null) const observerRef = useRef(null) @@ -80,7 +80,7 @@ const List = ({ isSelected={isSelected} onSelect={handleSelectFile} onOpen={handleOpenFolder} - isMultipleChoice={!isInPipeline} + isMultipleChoice={supportBatchUpload} /> ) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx index da8fd5dcc0..1d279e146d 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx @@ -20,14 +20,16 @@ import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/con type OnlineDriveProps = { nodeId: string nodeData: DataSourceNodeType - isInPipeline?: boolean onCredentialChange: (credentialId: string) => void + isInPipeline?: boolean + supportBatchUpload?: boolean } const OnlineDrive = ({ nodeId, nodeData, isInPipeline = false, + supportBatchUpload = true, onCredentialChange, }: OnlineDriveProps) => { const docLink = useDocLink() @@ -111,7 +113,7 @@ const OnlineDrive = ({ }, }, ) - }, [datasourceNodeRunURL, dataSourceStore]) + }, [dataSourceStore, datasourceNodeRunURL, breadcrumbs]) useEffect(() => { if (!currentCredentialId) return @@ -152,12 +154,12 @@ const OnlineDrive = ({ draft.splice(index, 1) } else { - if (isInPipeline && draft.length >= 1) return + if (!supportBatchUpload && draft.length >= 1) return draft.push(file.id) } }) setSelectedFileIds(newSelectedFileList) - }, [dataSourceStore, isInPipeline]) + }, [dataSourceStore, supportBatchUpload]) const handleOpenFolder = useCallback((file: OnlineDriveFile) => { const { breadcrumbs, prefix, setBreadcrumbs, setPrefix, setBucket, setOnlineDriveFileList, setSelectedFileIds } = dataSourceStore.getState() @@ -177,7 +179,7 @@ const OnlineDrive = ({ setBreadcrumbs(newBreadcrumbs) setPrefix(newPrefix) } - }, [dataSourceStore, getOnlineDriveFiles]) + }, [dataSourceStore]) const handleSetting = useCallback(() => { setShowAccountSettingModal({ @@ -209,6 +211,7 @@ const OnlineDrive = ({ handleOpenFolder={handleOpenFolder} isInPipeline={isInPipeline} isLoading={isLoading} + supportBatchUpload={supportBatchUpload} />
) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result-item.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result-item.tsx index 753b32c396..bdfcddfd77 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result-item.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result-item.tsx @@ -46,6 +46,7 @@ const CrawledResultItem = ({ /> ) : ( diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/index.tsx index 648f6a5d93..d9981a4638 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/index.tsx @@ -33,14 +33,16 @@ const I18N_PREFIX = 'datasetCreation.stepOne.website' export type WebsiteCrawlProps = { nodeId: string nodeData: DataSourceNodeType - isInPipeline?: boolean onCredentialChange: (credentialId: string) => void + isInPipeline?: boolean + supportBatchUpload?: boolean } const WebsiteCrawl = ({ nodeId, nodeData, isInPipeline = false, + supportBatchUpload = true, onCredentialChange, }: WebsiteCrawlProps) => { const { t } = useTranslation() @@ -122,7 +124,7 @@ const WebsiteCrawl = ({ time_consuming: time_consuming ?? 0, } setCrawlResult(crawlResultData) - handleCheckedCrawlResultChange(isInPipeline ? [crawlData[0]] : crawlData) // default select the crawl result + handleCheckedCrawlResultChange(supportBatchUpload ? crawlData : crawlData.slice(0, 1)) // default select the crawl result setCrawlErrorMessage('') setStep(CrawlStep.finished) }, @@ -132,7 +134,7 @@ const WebsiteCrawl = ({ }, }, ) - }, [dataSourceStore, datasourceNodeRunURL, handleCheckedCrawlResultChange, isInPipeline, t]) + }, [dataSourceStore, datasourceNodeRunURL, handleCheckedCrawlResultChange, supportBatchUpload, t]) const handleSubmit = useCallback((value: Record) => { handleRun(value) @@ -149,7 +151,7 @@ const WebsiteCrawl = ({ setTotalNum(0) setCrawlErrorMessage('') onCredentialChange(credentialId) - }, [dataSourceStore, onCredentialChange]) + }, [onCredentialChange]) return (
@@ -195,7 +197,7 @@ const WebsiteCrawl = ({ previewIndex={previewIndex} onPreview={handlePreview} showPreview={!isInPipeline} - isMultipleChoice={!isInPipeline} // only support single choice in test run + isMultipleChoice={supportBatchUpload} // only support single choice in test run /> )}
diff --git a/web/app/components/datasets/documents/create-from-pipeline/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/index.tsx index 77b77700ca..79e3694da8 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/index.tsx @@ -36,6 +36,10 @@ import { useAddDocumentsSteps, useLocalFile, useOnlineDocument, useOnlineDrive, import DataSourceProvider from './data-source/store/provider' import { useDataSourceStore } from './data-source/store' import { useFileUploadConfig } from '@/service/use-common' +import UpgradeCard from '../../create/step-one/upgrade-card' +import Divider from '@/app/components/base/divider' +import { useBoolean } from 'ahooks' +import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal' const CreateFormPipeline = () => { const { t } = useTranslation() @@ -57,7 +61,7 @@ const CreateFormPipeline = () => { const { steps, currentStep, - handleNextStep, + handleNextStep: doHandleNextStep, handleBackStep, } = useAddDocumentsSteps() const { @@ -102,7 +106,34 @@ const CreateFormPipeline = () => { return onlineDriveFileList.length > 0 && isVectorSpaceFull && enableBilling return false }, [allFileLoaded, datasource, datasourceType, enableBilling, isVectorSpaceFull, onlineDocuments.length, onlineDriveFileList.length, websitePages.length]) - const notSupportBatchUpload = enableBilling && plan.type === 'sandbox' + const supportBatchUpload = !enableBilling || plan.type !== 'sandbox' + + const [isShowPlanUpgradeModal, { + setTrue: showPlanUpgradeModal, + setFalse: hidePlanUpgradeModal, + }] = useBoolean(false) + const handleNextStep = useCallback(() => { + if (!supportBatchUpload) { + let isMultiple = false + if (datasourceType === DatasourceType.localFile && localFileList.length > 1) + isMultiple = true + + if (datasourceType === DatasourceType.onlineDocument && onlineDocuments.length > 1) + isMultiple = true + + if (datasourceType === DatasourceType.websiteCrawl && websitePages.length > 1) + isMultiple = true + + if (datasourceType === DatasourceType.onlineDrive && selectedFileIds.length > 1) + isMultiple = true + + if (isMultiple) { + showPlanUpgradeModal() + return + } + } + doHandleNextStep() + }, [datasourceType, doHandleNextStep, localFileList.length, onlineDocuments.length, selectedFileIds.length, showPlanUpgradeModal, supportBatchUpload, websitePages.length]) const nextBtnDisabled = useMemo(() => { if (!datasource) return true @@ -133,6 +164,7 @@ const CreateFormPipeline = () => { return item.type !== 'bucket' }).length > 0 } + return false }, [currentWorkspace?.pages.length, datasourceType, onlineDriveFileList]) const totalOptions = useMemo(() => { @@ -389,13 +421,14 @@ const CreateFormPipeline = () => { }, [PagesMapAndSelectedPagesId, currentWorkspace?.pages, dataSourceStore, datasourceType]) const clearDataSourceData = useCallback((dataSource: Datasource) => { - if (dataSource.nodeData.provider_type === DatasourceType.onlineDocument) + const providerType = dataSource.nodeData.provider_type + if (providerType === DatasourceType.onlineDocument) clearOnlineDocumentData() - else if (dataSource.nodeData.provider_type === DatasourceType.websiteCrawl) + else if (providerType === DatasourceType.websiteCrawl) clearWebsiteCrawlData() - else if (dataSource.nodeData.provider_type === DatasourceType.onlineDrive) + else if (providerType === DatasourceType.onlineDrive) clearOnlineDriveData() - }, []) + }, [clearOnlineDocumentData, clearOnlineDriveData, clearWebsiteCrawlData]) const handleSwitchDataSource = useCallback((dataSource: Datasource) => { const { @@ -406,13 +439,13 @@ const CreateFormPipeline = () => { setCurrentCredentialId('') currentNodeIdRef.current = dataSource.nodeId setDatasource(dataSource) - }, [dataSourceStore]) + }, [clearDataSourceData, dataSourceStore]) const handleCredentialChange = useCallback((credentialId: string) => { const { setCurrentCredentialId } = dataSourceStore.getState() clearDataSourceData(datasource!) setCurrentCredentialId(credentialId) - }, [dataSourceStore, datasource]) + }, [clearDataSourceData, dataSourceStore, datasource]) if (isFetchingPipelineInfo) { return ( @@ -443,7 +476,7 @@ const CreateFormPipeline = () => { {datasourceType === DatasourceType.localFile && ( )} {datasourceType === DatasourceType.onlineDocument && ( @@ -479,6 +512,14 @@ const CreateFormPipeline = () => { handleNextStep={handleNextStep} tip={tip} /> + { + !supportBatchUpload && datasourceType === DatasourceType.localFile && localFileList.length > 0 && ( + <> + + + + ) + }
) } @@ -557,6 +598,14 @@ const CreateFormPipeline = () => {
) } + {isShowPlanUpgradeModal && ( + + )}
) } diff --git a/web/app/components/datasets/documents/create-from-pipeline/preview/web-preview.tsx b/web/app/components/datasets/documents/create-from-pipeline/preview/web-preview.tsx index bae4deb86e..ce7a5da24c 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/preview/web-preview.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/preview/web-preview.tsx @@ -27,7 +27,7 @@ const WebsitePreview = ({ {currentWebsite.source_url} · · - {`${formatNumberAbbreviated(currentWebsite.content.length)} ${t('datasetPipeline.addDocuments.characters')}`} + {`${formatNumberAbbreviated(currentWebsite.markdown.length)} ${t('datasetPipeline.addDocuments.characters')}`}
- {currentWebsite.content} + {currentWebsite.markdown}
) diff --git a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx index 317db84c43..2049ae0d03 100644 --- a/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx +++ b/web/app/components/datasets/documents/detail/batch-modal/csv-uploader.tsx @@ -13,11 +13,10 @@ import Button from '@/app/components/base/button' import type { FileItem } from '@/models/datasets' import { upload } from '@/service/base' import { getFileUploadErrorMessage } from '@/app/components/base/file-uploader/utils' -import useSWR from 'swr' -import { fetchFileUploadConfig } from '@/service/common' import SimplePieChart from '@/app/components/base/simple-pie-chart' import { Theme } from '@/types/app' import useTheme from '@/hooks/use-theme' +import { useFileUploadConfig } from '@/service/use-common' export type Props = { file: FileItem | undefined @@ -34,7 +33,7 @@ const CSVUploader: FC = ({ const dropRef = useRef(null) const dragRef = useRef(null) const fileUploader = useRef(null) - const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) + const { data: fileUploadConfigResponse } = useFileUploadConfig() const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? { file_size_limit: 15, }, [fileUploadConfigResponse]) diff --git a/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx index 4bed7b461d..c5d3bf5629 100644 --- a/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx +++ b/web/app/components/datasets/documents/detail/completed/common/action-buttons.tsx @@ -13,6 +13,7 @@ type IActionButtonsProps = { actionType?: 'edit' | 'add' handleRegeneration?: () => void isChildChunk?: boolean + showRegenerationButton?: boolean } const ActionButtons: FC = ({ @@ -22,6 +23,7 @@ const ActionButtons: FC = ({ actionType = 'edit', handleRegeneration, isChildChunk = false, + showRegenerationButton = true, }) => { const { t } = useTranslation() const docForm = useDocumentContext(s => s.docForm) @@ -54,7 +56,7 @@ const ActionButtons: FC = ({ ESC
- {(isParentChildParagraphMode && actionType === 'edit' && !isChildChunk) + {(isParentChildParagraphMode && actionType === 'edit' && !isChildChunk && showRegenerationButton) ? @@ -116,7 +137,7 @@ const SegmentAdd: FC = ({ } btnElement={
- +
} btnClassName={open => cn( @@ -129,7 +150,16 @@ const SegmentAdd: FC = ({ className='h-fit min-w-[128px]' disabled={embedding} /> + {isShowPlanUpgradeModal && ( + + )}
+ ) } export default React.memo(SegmentAdd) diff --git a/web/app/components/datasets/documents/detail/settings/document-settings.tsx b/web/app/components/datasets/documents/detail/settings/document-settings.tsx index 3bcb8ef3aa..16c90c925f 100644 --- a/web/app/components/datasets/documents/detail/settings/document-settings.tsx +++ b/web/app/components/datasets/documents/detail/settings/document-settings.tsx @@ -113,7 +113,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => { return [{ title: websiteInfo.title, source_url: websiteInfo.source_url, - content: websiteInfo.content, + markdown: websiteInfo.content, description: websiteInfo.description, }] }, [websiteInfo]) diff --git a/web/app/components/datasets/documents/detail/settings/pipeline-settings/index.tsx b/web/app/components/datasets/documents/detail/settings/pipeline-settings/index.tsx index 1ab47be445..0381222415 100644 --- a/web/app/components/datasets/documents/detail/settings/pipeline-settings/index.tsx +++ b/web/app/components/datasets/documents/detail/settings/pipeline-settings/index.tsx @@ -55,7 +55,7 @@ const PipelineSettings = ({ if (lastRunData?.datasource_type === DatasourceType.websiteCrawl) { const { content, description, source_url, title } = lastRunData.datasource_info websitePages.push({ - content, + markdown: content, description, source_url, title, @@ -135,7 +135,7 @@ const PipelineSettings = ({ push(`/datasets/${datasetId}/documents`) }, }) - }, [datasetId, invalidDocumentDetail, invalidDocumentList, lastRunData, pipelineId, push, runPublishedPipeline]) + }, [datasetId, documentId, invalidDocumentDetail, invalidDocumentList, lastRunData, pipelineId, push, runPublishedPipeline]) const onClickProcess = useCallback(() => { isPreview.current = false diff --git a/web/app/components/datasets/documents/list.tsx b/web/app/components/datasets/documents/list.tsx index 6f95d3cecb..5c9f832bb8 100644 --- a/web/app/components/datasets/documents/list.tsx +++ b/web/app/components/datasets/documents/list.tsx @@ -2,9 +2,9 @@ import type { FC } from 'react' import React, { useCallback, useEffect, useMemo, useState } from 'react' import { useBoolean } from 'ahooks' -import { ArrowDownIcon } from '@heroicons/react/24/outline' import { pick, uniq } from 'lodash-es' import { + RiArrowDownLine, RiEditLine, RiGlobalLine, } from '@remixicon/react' @@ -181,8 +181,8 @@ const DocumentList: FC = ({ return (
handleSort(field)}> {label} - ({ + useRouter: () => ({ + back: mockRouterBack, + replace: mockReplace, + push: jest.fn(), + refresh: jest.fn(), + }), +})) + +// Mock react-i18next +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock useDocLink hook +jest.mock('@/context/i18n', () => ({ + useDocLink: () => (path?: string) => `https://docs.dify.ai/en${path || ''}`, +})) + +// Mock toast context +const mockNotify = jest.fn() +jest.mock('@/app/components/base/toast', () => ({ + useToastContext: () => ({ + notify: mockNotify, + }), +})) + +// Mock modal context +jest.mock('@/context/modal-context', () => ({ + useModalContext: () => ({ + setShowExternalKnowledgeAPIModal: jest.fn(), + }), +})) + +// Mock API service +jest.mock('@/service/datasets', () => ({ + createExternalKnowledgeBase: jest.fn(), +})) + +// Factory function to create mock ExternalAPIItem +const createMockExternalAPIItem = (overrides: Partial = {}): ExternalAPIItem => ({ + id: 'api-default', + tenant_id: 'tenant-1', + name: 'Default API', + description: 'Default API description', + settings: { + endpoint: 'https://api.example.com', + api_key: 'test-api-key', + }, + dataset_bindings: [], + created_by: 'user-1', + created_at: '2024-01-01T00:00:00Z', + ...overrides, +}) + +// Default mock API list +const createDefaultMockApiList = (): ExternalAPIItem[] => [ + createMockExternalAPIItem({ + id: 'api-1', + name: 'Test API 1', + settings: { endpoint: 'https://api1.example.com', api_key: 'key-1' }, + }), + createMockExternalAPIItem({ + id: 'api-2', + name: 'Test API 2', + settings: { endpoint: 'https://api2.example.com', api_key: 'key-2' }, + }), +] + +let mockExternalKnowledgeApiList: ExternalAPIItem[] = createDefaultMockApiList() + +jest.mock('@/context/external-knowledge-api-context', () => ({ + useExternalKnowledgeApi: () => ({ + externalKnowledgeApiList: mockExternalKnowledgeApiList, + mutateExternalKnowledgeApis: jest.fn(), + isLoading: false, + }), +})) + +// Suppress console.error helper +const suppressConsoleError = () => jest.spyOn(console, 'error').mockImplementation(jest.fn()) + +// Helper to create a pending promise with external resolver +function createPendingPromise() { + let resolve: (value: T) => void = jest.fn() + const promise = new Promise((r) => { + resolve = r + }) + return { promise, resolve } +} + +// Helper to fill required form fields and submit +async function fillFormAndSubmit(user: ReturnType) { + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'Test Knowledge Base' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'kb-123' } }) + + // Wait for button to be enabled + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) +} + +describe('ExternalKnowledgeBaseConnector', () => { + beforeEach(() => { + jest.clearAllMocks() + mockExternalKnowledgeApiList = createDefaultMockApiList() + ;(createExternalKnowledgeBase as jest.Mock).mockResolvedValue({ id: 'new-kb-id' }) + }) + + // Tests for rendering with real ExternalKnowledgeBaseCreate component + describe('Rendering', () => { + it('should render the create form with all required elements', () => { + render() + + // Verify main title and form elements + expect(screen.getByText('dataset.connectDataset')).toBeInTheDocument() + expect(screen.getByText('dataset.externalKnowledgeName')).toBeInTheDocument() + expect(screen.getByText('dataset.externalKnowledgeId')).toBeInTheDocument() + expect(screen.getByText('dataset.retrievalSettings')).toBeInTheDocument() + + // Verify buttons + expect(screen.getByText('dataset.externalKnowledgeForm.cancel')).toBeInTheDocument() + expect(screen.getByText('dataset.externalKnowledgeForm.connect')).toBeInTheDocument() + }) + + it('should render connect button disabled initially', () => { + render() + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).toBeDisabled() + }) + }) + + // Tests for API success flow + describe('API Success Flow', () => { + it('should call API and show success notification when form is submitted', async () => { + const user = userEvent.setup() + render() + + await fillFormAndSubmit(user) + + // Verify API was called with form data + await waitFor(() => { + expect(createExternalKnowledgeBase).toHaveBeenCalledWith({ + body: expect.objectContaining({ + name: 'Test Knowledge Base', + external_knowledge_id: 'kb-123', + external_knowledge_api_id: 'api-1', + provider: 'external', + }), + }) + }) + + // Verify success notification + expect(mockNotify).toHaveBeenCalledWith({ + type: 'success', + message: 'External Knowledge Base Connected Successfully', + }) + + // Verify navigation back + expect(mockRouterBack).toHaveBeenCalledTimes(1) + }) + + it('should include retrieval settings in API call', async () => { + const user = userEvent.setup() + render() + + await fillFormAndSubmit(user) + + await waitFor(() => { + expect(createExternalKnowledgeBase).toHaveBeenCalledWith({ + body: expect.objectContaining({ + external_retrieval_model: expect.objectContaining({ + top_k: 4, + score_threshold: 0.5, + score_threshold_enabled: false, + }), + }), + }) + }) + }) + }) + + // Tests for API error flow + describe('API Error Flow', () => { + it('should show error notification when API fails', async () => { + const user = userEvent.setup() + const consoleErrorSpy = suppressConsoleError() + ;(createExternalKnowledgeBase as jest.Mock).mockRejectedValue(new Error('Network Error')) + + render() + + await fillFormAndSubmit(user) + + // Verify error notification + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'Failed to connect External Knowledge Base', + }) + }) + + // Verify no navigation + expect(mockRouterBack).not.toHaveBeenCalled() + + consoleErrorSpy.mockRestore() + }) + + it('should show error notification when API returns invalid result', async () => { + const user = userEvent.setup() + const consoleErrorSpy = suppressConsoleError() + ;(createExternalKnowledgeBase as jest.Mock).mockResolvedValue({}) + + render() + + await fillFormAndSubmit(user) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'Failed to connect External Knowledge Base', + }) + }) + + expect(mockRouterBack).not.toHaveBeenCalled() + + consoleErrorSpy.mockRestore() + }) + }) + + // Tests for loading state + describe('Loading State', () => { + it('should show loading state during API call', async () => { + const user = userEvent.setup() + + // Create a promise that won't resolve immediately + const { promise, resolve: resolvePromise } = createPendingPromise<{ id: string }>() + ;(createExternalKnowledgeBase as jest.Mock).mockReturnValue(promise) + + render() + + // Fill form + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + fireEvent.change(nameInput, { target: { value: 'Test' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'kb-1' } }) + + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + // Click connect + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) + + // Button should show loading (the real Button component has loading prop) + await waitFor(() => { + expect(createExternalKnowledgeBase).toHaveBeenCalled() + }) + + // Resolve the promise + resolvePromise({ id: 'new-id' }) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith({ + type: 'success', + message: 'External Knowledge Base Connected Successfully', + }) + }) + }) + }) + + // Tests for form validation (integration with real create component) + describe('Form Validation', () => { + it('should keep button disabled when only name is filled', () => { + render() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + fireEvent.change(nameInput, { target: { value: 'Test' } }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).toBeDisabled() + }) + + it('should keep button disabled when only knowledge id is filled', () => { + render() + + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + fireEvent.change(knowledgeIdInput, { target: { value: 'kb-1' } }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).toBeDisabled() + }) + + it('should enable button when all required fields are filled', async () => { + render() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'Test' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'kb-1' } }) + + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + }) + }) + + // Tests for user interactions + describe('User Interactions', () => { + it('should allow typing in form fields', async () => { + const user = userEvent.setup() + render() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const descriptionInput = screen.getByPlaceholderText('dataset.externalKnowledgeDescriptionPlaceholder') + + await user.type(nameInput, 'My Knowledge Base') + await user.type(descriptionInput, 'My Description') + + expect((nameInput as HTMLInputElement).value).toBe('My Knowledge Base') + expect((descriptionInput as HTMLTextAreaElement).value).toBe('My Description') + }) + + it('should handle cancel button click', async () => { + const user = userEvent.setup() + render() + + const cancelButton = screen.getByText('dataset.externalKnowledgeForm.cancel').closest('button') + await user.click(cancelButton!) + + expect(mockReplace).toHaveBeenCalledWith('/datasets') + }) + + it('should handle back button click', async () => { + const user = userEvent.setup() + render() + + const buttons = screen.getAllByRole('button') + const backButton = buttons.find(btn => btn.classList.contains('rounded-full')) + await user.click(backButton!) + + expect(mockReplace).toHaveBeenCalledWith('/datasets') + }) + }) +}) diff --git a/web/app/components/datasets/external-knowledge-base/create/index.spec.tsx b/web/app/components/datasets/external-knowledge-base/create/index.spec.tsx new file mode 100644 index 0000000000..c315743424 --- /dev/null +++ b/web/app/components/datasets/external-knowledge-base/create/index.spec.tsx @@ -0,0 +1,1059 @@ +import React from 'react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import type { ExternalAPIItem } from '@/models/datasets' +import ExternalKnowledgeBaseCreate from './index' + +// Mock next/navigation +const mockReplace = jest.fn() +const mockRefresh = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: () => ({ + replace: mockReplace, + push: jest.fn(), + refresh: mockRefresh, + }), +})) + +// Mock react-i18next +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock useDocLink hook +jest.mock('@/context/i18n', () => ({ + useDocLink: () => (path?: string) => `https://docs.dify.ai/en${path || ''}`, +})) + +// Mock external context providers (these are external dependencies) +const mockSetShowExternalKnowledgeAPIModal = jest.fn() +jest.mock('@/context/modal-context', () => ({ + useModalContext: () => ({ + setShowExternalKnowledgeAPIModal: mockSetShowExternalKnowledgeAPIModal, + }), +})) + +// Factory function to create mock ExternalAPIItem (following project conventions) +const createMockExternalAPIItem = (overrides: Partial = {}): ExternalAPIItem => ({ + id: 'api-default', + tenant_id: 'tenant-1', + name: 'Default API', + description: 'Default API description', + settings: { + endpoint: 'https://api.example.com', + api_key: 'test-api-key', + }, + dataset_bindings: [], + created_by: 'user-1', + created_at: '2024-01-01T00:00:00Z', + ...overrides, +}) + +// Default mock API list +const createDefaultMockApiList = (): ExternalAPIItem[] => [ + createMockExternalAPIItem({ + id: 'api-1', + name: 'Test API 1', + settings: { endpoint: 'https://api1.example.com', api_key: 'key-1' }, + }), + createMockExternalAPIItem({ + id: 'api-2', + name: 'Test API 2', + settings: { endpoint: 'https://api2.example.com', api_key: 'key-2' }, + }), +] + +const mockMutateExternalKnowledgeApis = jest.fn() +let mockExternalKnowledgeApiList: ExternalAPIItem[] = createDefaultMockApiList() + +jest.mock('@/context/external-knowledge-api-context', () => ({ + useExternalKnowledgeApi: () => ({ + externalKnowledgeApiList: mockExternalKnowledgeApiList, + mutateExternalKnowledgeApis: mockMutateExternalKnowledgeApis, + isLoading: false, + }), +})) + +// Helper to render component with default props +const renderComponent = (props: Partial> = {}) => { + const defaultProps = { + onConnect: jest.fn(), + loading: false, + } + return render() +} + +describe('ExternalKnowledgeBaseCreate', () => { + beforeEach(() => { + jest.clearAllMocks() + // Reset API list to default using factory function + mockExternalKnowledgeApiList = createDefaultMockApiList() + }) + + // Tests for basic rendering + describe('Rendering', () => { + it('should render without crashing', () => { + renderComponent() + + expect(screen.getByText('dataset.connectDataset')).toBeInTheDocument() + }) + + it('should render KnowledgeBaseInfo component with correct labels', () => { + renderComponent() + + // KnowledgeBaseInfo renders these labels + expect(screen.getByText('dataset.externalKnowledgeName')).toBeInTheDocument() + expect(screen.getByText('dataset.externalKnowledgeDescription')).toBeInTheDocument() + }) + + it('should render ExternalApiSelection component', () => { + renderComponent() + + // ExternalApiSelection renders this label + expect(screen.getByText('dataset.externalAPIPanelTitle')).toBeInTheDocument() + expect(screen.getByText('dataset.externalKnowledgeId')).toBeInTheDocument() + }) + + it('should render RetrievalSettings component', () => { + renderComponent() + + // RetrievalSettings renders this label + expect(screen.getByText('dataset.retrievalSettings')).toBeInTheDocument() + }) + + it('should render InfoPanel component', () => { + renderComponent() + + // InfoPanel renders these texts + expect(screen.getByText('dataset.connectDatasetIntro.title')).toBeInTheDocument() + expect(screen.getByText('dataset.connectDatasetIntro.learnMore')).toBeInTheDocument() + }) + + it('should render helper text with translation keys', () => { + renderComponent() + + expect(screen.getByText('dataset.connectHelper.helper1')).toBeInTheDocument() + expect(screen.getByText('dataset.connectHelper.helper2')).toBeInTheDocument() + expect(screen.getByText('dataset.connectHelper.helper3')).toBeInTheDocument() + expect(screen.getByText('dataset.connectHelper.helper4')).toBeInTheDocument() + expect(screen.getByText('dataset.connectHelper.helper5')).toBeInTheDocument() + }) + + it('should render cancel and connect buttons', () => { + renderComponent() + + expect(screen.getByText('dataset.externalKnowledgeForm.cancel')).toBeInTheDocument() + expect(screen.getByText('dataset.externalKnowledgeForm.connect')).toBeInTheDocument() + }) + + it('should render documentation link with correct href', () => { + renderComponent() + + const docLink = screen.getByText('dataset.connectHelper.helper4') + expect(docLink).toHaveAttribute('href', 'https://docs.dify.ai/en/guides/knowledge-base/connect-external-knowledge-base') + expect(docLink).toHaveAttribute('target', '_blank') + expect(docLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + }) + + // Tests for props handling + describe('Props', () => { + it('should pass loading prop to connect button', () => { + renderComponent({ loading: true }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).toBeInTheDocument() + }) + + it('should call onConnect with form data when connect button is clicked', async () => { + const user = userEvent.setup() + const onConnect = jest.fn() + renderComponent({ onConnect }) + + // Fill in name field (using the actual Input component) + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + fireEvent.change(nameInput, { target: { value: 'Test Knowledge Base' } }) + + // Fill in external knowledge id + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + fireEvent.change(knowledgeIdInput, { target: { value: 'knowledge-456' } }) + + // Wait for useEffect to auto-select the first API + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) + + expect(onConnect).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'Test Knowledge Base', + external_knowledge_id: 'knowledge-456', + external_knowledge_api_id: 'api-1', // Auto-selected first API + provider: 'external', + }), + ) + }) + + it('should not call onConnect when form is invalid and button is disabled', async () => { + const user = userEvent.setup() + const onConnect = jest.fn() + renderComponent({ onConnect }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).toBeDisabled() + + await user.click(connectButton!) + expect(onConnect).not.toHaveBeenCalled() + }) + }) + + // Tests for state management with real child components + describe('State Management', () => { + it('should initialize form data with default values', () => { + renderComponent() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') as HTMLInputElement + const descriptionInput = screen.getByPlaceholderText('dataset.externalKnowledgeDescriptionPlaceholder') as HTMLTextAreaElement + + expect(nameInput.value).toBe('') + expect(descriptionInput.value).toBe('') + }) + + it('should update name when input changes', () => { + renderComponent() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + fireEvent.change(nameInput, { target: { value: 'New Name' } }) + + expect((nameInput as HTMLInputElement).value).toBe('New Name') + }) + + it('should update description when textarea changes', () => { + renderComponent() + + const descriptionInput = screen.getByPlaceholderText('dataset.externalKnowledgeDescriptionPlaceholder') + fireEvent.change(descriptionInput, { target: { value: 'New Description' } }) + + expect((descriptionInput as HTMLTextAreaElement).value).toBe('New Description') + }) + + it('should update external_knowledge_id when input changes', () => { + renderComponent() + + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + fireEvent.change(knowledgeIdInput, { target: { value: 'new-knowledge-id' } }) + + expect((knowledgeIdInput as HTMLInputElement).value).toBe('new-knowledge-id') + }) + + it('should apply filled text style when description has value', () => { + renderComponent() + + const descriptionInput = screen.getByPlaceholderText('dataset.externalKnowledgeDescriptionPlaceholder') as HTMLTextAreaElement + + // Initially empty - should have placeholder style + expect(descriptionInput.className).toContain('text-components-input-text-placeholder') + + // Add description - should have filled style + fireEvent.change(descriptionInput, { target: { value: 'Some description' } }) + expect(descriptionInput.className).toContain('text-components-input-text-filled') + }) + + it('should apply placeholder text style when description is empty', () => { + renderComponent() + + const descriptionInput = screen.getByPlaceholderText('dataset.externalKnowledgeDescriptionPlaceholder') as HTMLTextAreaElement + + // Add then clear description + fireEvent.change(descriptionInput, { target: { value: 'Some description' } }) + fireEvent.change(descriptionInput, { target: { value: '' } }) + + expect(descriptionInput.className).toContain('text-components-input-text-placeholder') + }) + }) + + // Tests for form validation + describe('Form Validation', () => { + it('should disable connect button when name is empty', async () => { + renderComponent() + + // Fill knowledge id but not name + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + fireEvent.change(knowledgeIdInput, { target: { value: 'knowledge-456' } }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).toBeDisabled() + }) + + it('should disable connect button when name is only whitespace', async () => { + renderComponent() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: ' ' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'knowledge-456' } }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).toBeDisabled() + }) + + it('should disable connect button when external_knowledge_id is empty', () => { + renderComponent() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + fireEvent.change(nameInput, { target: { value: 'Test Name' } }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).toBeDisabled() + }) + + it('should enable connect button when all required fields are filled', async () => { + renderComponent() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'Test Name' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'knowledge-456' } }) + + // Wait for auto-selection of API + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + }) + }) + + // Tests for user interactions + describe('User Interactions', () => { + it('should navigate back when back button is clicked', async () => { + const user = userEvent.setup() + renderComponent() + + const buttons = screen.getAllByRole('button') + const backButton = buttons.find(btn => btn.classList.contains('rounded-full')) + await user.click(backButton!) + + expect(mockReplace).toHaveBeenCalledWith('/datasets') + }) + + it('should navigate back when cancel button is clicked', async () => { + const user = userEvent.setup() + renderComponent() + + const cancelButton = screen.getByText('dataset.externalKnowledgeForm.cancel').closest('button') + await user.click(cancelButton!) + + expect(mockReplace).toHaveBeenCalledWith('/datasets') + }) + + it('should call onConnect with complete form data when connect is clicked', async () => { + const user = userEvent.setup() + const onConnect = jest.fn() + renderComponent({ onConnect }) + + // Fill all fields using real components + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const descriptionInput = screen.getByPlaceholderText('dataset.externalKnowledgeDescriptionPlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'My Knowledge Base' } }) + fireEvent.change(descriptionInput, { target: { value: 'Test description' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'knowledge-abc' } }) + + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) + + expect(onConnect).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'My Knowledge Base', + description: 'Test description', + external_knowledge_id: 'knowledge-abc', + provider: 'external', + }), + ) + }) + + it('should allow user to type in all input fields', async () => { + const user = userEvent.setup() + renderComponent() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const descriptionInput = screen.getByPlaceholderText('dataset.externalKnowledgeDescriptionPlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + await user.type(nameInput, 'Typed Name') + await user.type(descriptionInput, 'Typed Description') + await user.type(knowledgeIdInput, 'typed-knowledge') + + expect((nameInput as HTMLInputElement).value).toBe('Typed Name') + expect((descriptionInput as HTMLTextAreaElement).value).toBe('Typed Description') + expect((knowledgeIdInput as HTMLInputElement).value).toBe('typed-knowledge') + }) + }) + + // Tests for ExternalApiSelection integration + describe('ExternalApiSelection Integration', () => { + it('should auto-select first API when API list is available', async () => { + const user = userEvent.setup() + const onConnect = jest.fn() + renderComponent({ onConnect }) + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'Test' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'kb-1' } }) + + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) + + // Should have auto-selected the first API + expect(onConnect).toHaveBeenCalledWith( + expect.objectContaining({ + external_knowledge_api_id: 'api-1', + }), + ) + }) + + it('should display API selector when APIs are available', () => { + renderComponent() + + // The ExternalApiSelect should show the first selected API name + expect(screen.getByText('Test API 1')).toBeInTheDocument() + }) + + it('should allow selecting different API from dropdown', async () => { + const user = userEvent.setup() + const onConnect = jest.fn() + renderComponent({ onConnect }) + + // Click on the API selector to open dropdown + const apiSelector = screen.getByText('Test API 1') + await user.click(apiSelector) + + // Select the second API + const secondApi = screen.getByText('Test API 2') + await user.click(secondApi) + + // Fill required fields + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'Test' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'kb-1' } }) + + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) + + // Should have selected the second API + expect(onConnect).toHaveBeenCalledWith( + expect.objectContaining({ + external_knowledge_api_id: 'api-2', + }), + ) + }) + + it('should show add API button when no APIs are available', () => { + // Set empty API list + mockExternalKnowledgeApiList = [] + renderComponent() + + // Should show "no external knowledge" button + expect(screen.getByText('dataset.noExternalKnowledge')).toBeInTheDocument() + }) + + it('should open add API modal when add button is clicked', async () => { + const user = userEvent.setup() + // Set empty API list + mockExternalKnowledgeApiList = [] + renderComponent() + + // Click the add button + const addButton = screen.getByText('dataset.noExternalKnowledge').closest('button') + await user.click(addButton!) + + // Should call the modal context function + expect(mockSetShowExternalKnowledgeAPIModal).toHaveBeenCalledWith( + expect.objectContaining({ + payload: { name: '', settings: { endpoint: '', api_key: '' } }, + isEditMode: false, + }), + ) + }) + + it('should call mutate and router.refresh on modal save callback', async () => { + const user = userEvent.setup() + // Set empty API list + mockExternalKnowledgeApiList = [] + renderComponent() + + // Click the add button + const addButton = screen.getByText('dataset.noExternalKnowledge').closest('button') + await user.click(addButton!) + + // Get the callback and invoke it + const modalCall = mockSetShowExternalKnowledgeAPIModal.mock.calls[0][0] + await modalCall.onSaveCallback() + + expect(mockMutateExternalKnowledgeApis).toHaveBeenCalled() + expect(mockRefresh).toHaveBeenCalled() + }) + + it('should call mutate on modal cancel callback', async () => { + const user = userEvent.setup() + // Set empty API list + mockExternalKnowledgeApiList = [] + renderComponent() + + // Click the add button + const addButton = screen.getByText('dataset.noExternalKnowledge').closest('button') + await user.click(addButton!) + + // Get the callback and invoke it + const modalCall = mockSetShowExternalKnowledgeAPIModal.mock.calls[0][0] + modalCall.onCancelCallback() + + expect(mockMutateExternalKnowledgeApis).toHaveBeenCalled() + }) + + it('should display API URL in dropdown', async () => { + const user = userEvent.setup() + renderComponent() + + // Click on the API selector to open dropdown + const apiSelector = screen.getByText('Test API 1') + await user.click(apiSelector) + + // Should show API URLs + expect(screen.getByText('https://api1.example.com')).toBeInTheDocument() + expect(screen.getByText('https://api2.example.com')).toBeInTheDocument() + }) + + it('should show create new API option in dropdown', async () => { + const user = userEvent.setup() + renderComponent() + + // Click on the API selector to open dropdown + const apiSelector = screen.getByText('Test API 1') + await user.click(apiSelector) + + // Should show create new API option + expect(screen.getByText('dataset.createNewExternalAPI')).toBeInTheDocument() + }) + + it('should open add API modal when clicking create new API in dropdown', async () => { + const user = userEvent.setup() + renderComponent() + + // Click on the API selector to open dropdown + const apiSelector = screen.getByText('Test API 1') + await user.click(apiSelector) + + // Click on create new API option + const createNewApiOption = screen.getByText('dataset.createNewExternalAPI') + await user.click(createNewApiOption) + + // Should call the modal context function + expect(mockSetShowExternalKnowledgeAPIModal).toHaveBeenCalledWith( + expect.objectContaining({ + payload: { name: '', settings: { endpoint: '', api_key: '' } }, + isEditMode: false, + }), + ) + }) + + it('should call mutate and refresh on save callback from ExternalApiSelect dropdown', async () => { + const user = userEvent.setup() + renderComponent() + + // Click on the API selector to open dropdown + const apiSelector = screen.getByText('Test API 1') + await user.click(apiSelector) + + // Click on create new API option + const createNewApiOption = screen.getByText('dataset.createNewExternalAPI') + await user.click(createNewApiOption) + + // Get the callback from the modal call and invoke it + const modalCall = mockSetShowExternalKnowledgeAPIModal.mock.calls[0][0] + await modalCall.onSaveCallback() + + expect(mockMutateExternalKnowledgeApis).toHaveBeenCalled() + expect(mockRefresh).toHaveBeenCalled() + }) + + it('should call mutate on cancel callback from ExternalApiSelect dropdown', async () => { + const user = userEvent.setup() + renderComponent() + + // Click on the API selector to open dropdown + const apiSelector = screen.getByText('Test API 1') + await user.click(apiSelector) + + // Click on create new API option + const createNewApiOption = screen.getByText('dataset.createNewExternalAPI') + await user.click(createNewApiOption) + + // Get the callback from the modal call and invoke it + const modalCall = mockSetShowExternalKnowledgeAPIModal.mock.calls[0][0] + modalCall.onCancelCallback() + + expect(mockMutateExternalKnowledgeApis).toHaveBeenCalled() + }) + + it('should close dropdown after selecting an API', async () => { + const user = userEvent.setup() + renderComponent() + + // Click on the API selector to open dropdown + const apiSelector = screen.getByText('Test API 1') + await user.click(apiSelector) + + // Dropdown should be open - API URLs visible + expect(screen.getByText('https://api1.example.com')).toBeInTheDocument() + + // Select the second API + const secondApi = screen.getByText('Test API 2') + await user.click(secondApi) + + // Dropdown should be closed - API URLs not visible + expect(screen.queryByText('https://api1.example.com')).not.toBeInTheDocument() + }) + + it('should toggle dropdown open/close on selector click', async () => { + const user = userEvent.setup() + renderComponent() + + // Click to open + const apiSelector = screen.getByText('Test API 1') + await user.click(apiSelector) + expect(screen.getByText('https://api1.example.com')).toBeInTheDocument() + + // Click again to close + await user.click(apiSelector) + expect(screen.queryByText('https://api1.example.com')).not.toBeInTheDocument() + }) + }) + + // Tests for callback stability + describe('Callback Stability', () => { + it('should maintain stable navBackHandle callback reference', async () => { + const user = userEvent.setup() + const { rerender } = render( + , + ) + + const buttons = screen.getAllByRole('button') + const backButton = buttons.find(btn => btn.classList.contains('rounded-full')) + await user.click(backButton!) + + expect(mockReplace).toHaveBeenCalledTimes(1) + + rerender() + + await user.click(backButton!) + expect(mockReplace).toHaveBeenCalledTimes(2) + }) + + it('should not recreate handlers on prop changes', async () => { + const user = userEvent.setup() + const onConnect1 = jest.fn() + const onConnect2 = jest.fn() + + const { rerender } = render( + , + ) + + // Fill form + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'Test' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'knowledge' } }) + + // Rerender with new callback + rerender() + + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) + + // Should use the new callback + expect(onConnect1).not.toHaveBeenCalled() + expect(onConnect2).toHaveBeenCalled() + }) + }) + + // Tests for edge cases + describe('Edge Cases', () => { + it('should handle empty description gracefully', async () => { + const user = userEvent.setup() + const onConnect = jest.fn() + renderComponent({ onConnect }) + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'Test' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'knowledge' } }) + + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) + + expect(onConnect).toHaveBeenCalledWith( + expect.objectContaining({ + description: '', + }), + ) + }) + + it('should handle special characters in name', () => { + renderComponent() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const specialName = 'Test Name' + + fireEvent.change(nameInput, { target: { value: specialName } }) + + expect((nameInput as HTMLInputElement).value).toBe(specialName) + }) + + it('should handle very long input values', () => { + renderComponent() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const longName = 'A'.repeat(1000) + + fireEvent.change(nameInput, { target: { value: longName } }) + + expect((nameInput as HTMLInputElement).value).toBe(longName) + }) + + it('should handle rapid sequential updates', () => { + renderComponent() + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + + // Rapid updates + for (let i = 0; i < 10; i++) + fireEvent.change(nameInput, { target: { value: `Name ${i}` } }) + + expect((nameInput as HTMLInputElement).value).toBe('Name 9') + }) + + it('should preserve provider value as external', async () => { + const user = userEvent.setup() + const onConnect = jest.fn() + renderComponent({ onConnect }) + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'Test' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'knowledge' } }) + + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) + + expect(onConnect).toHaveBeenCalledWith( + expect.objectContaining({ + provider: 'external', + }), + ) + }) + }) + + // Tests for loading state + describe('Loading State', () => { + it('should pass loading state to connect button', () => { + renderComponent({ loading: true }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).toBeInTheDocument() + }) + + it('should render correctly when not loading', () => { + renderComponent({ loading: false }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).toBeInTheDocument() + }) + }) + + // Tests for RetrievalSettings integration + describe('RetrievalSettings Integration', () => { + it('should toggle score threshold enabled when switch is clicked', async () => { + const user = userEvent.setup() + const onConnect = jest.fn() + renderComponent({ onConnect }) + + // Find and click the switch for score threshold + const switches = screen.getAllByRole('switch') + const scoreThresholdSwitch = switches[0] // The score threshold switch + await user.click(scoreThresholdSwitch) + + // Fill required fields + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'Test' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'kb-1' } }) + + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) + + expect(onConnect).toHaveBeenCalledWith( + expect.objectContaining({ + external_retrieval_model: expect.objectContaining({ + score_threshold_enabled: true, + }), + }), + ) + }) + + it('should display retrieval settings labels', () => { + renderComponent() + + // Should show the retrieval settings section title + expect(screen.getByText('dataset.retrievalSettings')).toBeInTheDocument() + // Should show Top K and Score Threshold labels + expect(screen.getByText('appDebug.datasetConfig.top_k')).toBeInTheDocument() + expect(screen.getByText('appDebug.datasetConfig.score_threshold')).toBeInTheDocument() + }) + }) + + // Direct unit tests for RetrievalSettings component to cover all branches + describe('RetrievalSettings Component Direct Tests', () => { + // Import RetrievalSettings directly for unit testing + const RetrievalSettings = require('./RetrievalSettings').default + + it('should render with isInHitTesting mode', () => { + const onChange = jest.fn() + render( + , + ) + + // In hit testing mode, the title should not be shown + expect(screen.queryByText('dataset.retrievalSettings')).not.toBeInTheDocument() + }) + + it('should render with isInRetrievalSetting mode', () => { + const onChange = jest.fn() + render( + , + ) + + // In retrieval setting mode, the title should not be shown + expect(screen.queryByText('dataset.retrievalSettings')).not.toBeInTheDocument() + }) + + it('should call onChange with score_threshold_enabled when switch is toggled', async () => { + const user = userEvent.setup() + const onChange = jest.fn() + render( + , + ) + + // Find and click the switch + const switches = screen.getAllByRole('switch') + await user.click(switches[0]) + + expect(onChange).toHaveBeenCalledWith({ score_threshold_enabled: true }) + }) + + it('should call onChange with top_k when top k value changes', () => { + const onChange = jest.fn() + render( + , + ) + + // The TopKItem should render an input + const inputs = screen.getAllByRole('spinbutton') + const topKInput = inputs[0] + fireEvent.change(topKInput, { target: { value: '8' } }) + + expect(onChange).toHaveBeenCalledWith({ top_k: 8 }) + }) + + it('should call onChange with score_threshold when threshold value changes', () => { + const onChange = jest.fn() + render( + , + ) + + // The ScoreThresholdItem should render an input + const inputs = screen.getAllByRole('spinbutton') + const scoreThresholdInput = inputs[1] + fireEvent.change(scoreThresholdInput, { target: { value: '0.8' } }) + + expect(onChange).toHaveBeenCalledWith({ score_threshold: 0.8 }) + }) + }) + + // Tests for complete form submission flow + describe('Complete Form Submission Flow', () => { + it('should submit form with all default retrieval settings', async () => { + const user = userEvent.setup() + const onConnect = jest.fn() + renderComponent({ onConnect }) + + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'Test KB' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'kb-1' } }) + + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) + + expect(onConnect).toHaveBeenCalledWith({ + name: 'Test KB', + description: '', + external_knowledge_api_id: 'api-1', + external_knowledge_id: 'kb-1', + external_retrieval_model: { + top_k: 4, + score_threshold: 0.5, + score_threshold_enabled: false, + }, + provider: 'external', + }) + }) + + it('should submit form with modified retrieval settings', async () => { + const user = userEvent.setup() + const onConnect = jest.fn() + renderComponent({ onConnect }) + + // Toggle score threshold switch + const switches = screen.getAllByRole('switch') + const scoreThresholdSwitch = switches[0] + await user.click(scoreThresholdSwitch) + + // Fill required fields + const nameInput = screen.getByPlaceholderText('dataset.externalKnowledgeNamePlaceholder') + const knowledgeIdInput = screen.getByPlaceholderText('dataset.externalKnowledgeIdPlaceholder') + + fireEvent.change(nameInput, { target: { value: 'Custom KB' } }) + fireEvent.change(knowledgeIdInput, { target: { value: 'custom-kb' } }) + + await waitFor(() => { + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + expect(connectButton).not.toBeDisabled() + }) + + const connectButton = screen.getByText('dataset.externalKnowledgeForm.connect').closest('button') + await user.click(connectButton!) + + expect(onConnect).toHaveBeenCalledWith( + expect.objectContaining({ + name: 'Custom KB', + external_retrieval_model: expect.objectContaining({ + score_threshold_enabled: true, + }), + }), + ) + }) + }) + + // Tests for accessibility + describe('Accessibility', () => { + it('should have accessible buttons', () => { + renderComponent() + + const buttons = screen.getAllByRole('button') + expect(buttons.length).toBeGreaterThanOrEqual(3) // back, cancel, connect + }) + + it('should have proper link attributes for external links', () => { + renderComponent() + + const externalLink = screen.getByText('dataset.connectHelper.helper4') + expect(externalLink.tagName).toBe('A') + expect(externalLink).toHaveAttribute('target', '_blank') + expect(externalLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + + it('should have labels for form inputs', () => { + renderComponent() + + // Check labels exist + expect(screen.getByText('dataset.externalKnowledgeName')).toBeInTheDocument() + expect(screen.getByText('dataset.externalKnowledgeDescription')).toBeInTheDocument() + expect(screen.getByText('dataset.externalKnowledgeId')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/datasets/hit-testing/components/chunk-detail-modal.tsx b/web/app/components/datasets/hit-testing/components/chunk-detail-modal.tsx index ab848a5871..fb67089890 100644 --- a/web/app/components/datasets/hit-testing/components/chunk-detail-modal.tsx +++ b/web/app/components/datasets/hit-testing/components/chunk-detail-modal.tsx @@ -1,6 +1,5 @@ 'use client' -import type { FC } from 'react' -import React from 'react' +import React, { useMemo } from 'react' import { useTranslation } from 'react-i18next' import { SegmentIndexTag } from '../../documents/detail/completed/common/segment-index-tag' import Dot from '../../documents/detail/completed/common/dot' @@ -13,25 +12,42 @@ import type { FileAppearanceTypeEnum } from '@/app/components/base/file-uploader import cn from '@/utils/classnames' import Tag from '@/app/components/datasets/documents/detail/completed/common/tag' import { Markdown } from '@/app/components/base/markdown' +import ImageList from '../../common/image-list' +import Mask from './mask' const i18nPrefix = 'datasetHitTesting' -type Props = { +type ChunkDetailModalProps = { payload: HitTesting onHide: () => void } -const ChunkDetailModal: FC = ({ +const ChunkDetailModal = ({ payload, onHide, -}) => { +}: ChunkDetailModalProps) => { const { t } = useTranslation() - const { segment, score, child_chunks } = payload + const { segment, score, child_chunks, files } = payload const { position, content, sign_content, keywords, document, answer } = segment const isParentChildRetrieval = !!(child_chunks && child_chunks.length > 0) const extension = document.name.split('.').slice(-1)[0] as FileAppearanceTypeEnum const heighClassName = isParentChildRetrieval ? 'h-[min(627px,_80vh)] overflow-y-auto' : 'h-[min(539px,_80vh)] overflow-y-auto' const labelPrefix = isParentChildRetrieval ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk') + + const images = useMemo(() => { + if (!files) return [] + return files.map(file => ({ + name: file.name, + mimeType: file.mime_type, + sourceUrl: file.source_url, + size: file.size, + extension: file.extension, + })) + }, [files]) + + const showImages = images.length > 0 + const showKeywords = !isParentChildRetrieval && keywords && keywords.length > 0 + return ( = ({
- {!answer && ( - - )} - {answer && ( -
-
-
Q
-
- {content} + {/* Content */} +
+ {!answer && ( + + )} + {answer && ( +
+
+
Q
+
+ {content} +
+
+
+
A
+
+ {answer} +
-
-
A
-
- {answer} + )} + {/* Mask */} + +
+ {(showImages || showKeywords) && ( +
+ {showImages && ( + + )} + {showKeywords && ( +
+
{t(`${i18nPrefix}.keyword`)}
+
+ {keywords.map(keyword => ( + + ))} +
-
-
- )} - {!isParentChildRetrieval && keywords && keywords.length > 0 && ( -
-
{t(`${i18nPrefix}.keyword`)}
-
- {keywords.map(keyword => ( - - ))} -
+ )}
)}
diff --git a/web/app/components/datasets/hit-testing/components/empty-records.tsx b/web/app/components/datasets/hit-testing/components/empty-records.tsx new file mode 100644 index 0000000000..db7d724b17 --- /dev/null +++ b/web/app/components/datasets/hit-testing/components/empty-records.tsx @@ -0,0 +1,15 @@ +import { RiHistoryLine } from '@remixicon/react' +import React from 'react' +import { useTranslation } from 'react-i18next' + +const EmptyRecords = () => { + const { t } = useTranslation() + return
+
+ +
+
{t('datasetHitTesting.noRecentTip')}
+
+} + +export default React.memo(EmptyRecords) diff --git a/web/app/components/datasets/hit-testing/components/mask.tsx b/web/app/components/datasets/hit-testing/components/mask.tsx new file mode 100644 index 0000000000..799d7656b2 --- /dev/null +++ b/web/app/components/datasets/hit-testing/components/mask.tsx @@ -0,0 +1,19 @@ +import React from 'react' +import cn from '@/utils/classnames' + +type MaskProps = { + className?: string +} + +export const Mask = ({ + className, +}: MaskProps) => { + return ( +
+ ) +} + +export default React.memo(Mask) diff --git a/web/app/components/datasets/hit-testing/components/query-input/index.tsx b/web/app/components/datasets/hit-testing/components/query-input/index.tsx new file mode 100644 index 0000000000..75b59fe09a --- /dev/null +++ b/web/app/components/datasets/hit-testing/components/query-input/index.tsx @@ -0,0 +1,257 @@ +import type { ChangeEvent } from 'react' +import React, { useCallback, useMemo, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { + RiEqualizer2Line, + RiPlayCircleLine, +} from '@remixicon/react' +import Image from 'next/image' +import Button from '@/app/components/base/button' +import { getIcon } from '@/app/components/datasets/common/retrieval-method-info' +import ModifyExternalRetrievalModal from '@/app/components/datasets/hit-testing/modify-external-retrieval-modal' +import cn from '@/utils/classnames' +import type { + Attachment, + ExternalKnowledgeBaseHitTestingRequest, + ExternalKnowledgeBaseHitTestingResponse, + HitTestingRequest, + HitTestingResponse, + Query, +} from '@/models/datasets' +import { RETRIEVE_METHOD, type RetrievalConfig } from '@/types/app' +import type { UseMutateAsyncFunction } from '@tanstack/react-query' +import ImageUploaderInRetrievalTesting from '@/app/components/datasets/common/image-uploader/image-uploader-in-retrieval-testing' +import Textarea from './textarea' +import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' +import type { FileEntity } from '@/app/components/datasets/common/image-uploader/types' +import { v4 as uuid4 } from 'uuid' + +type QueryInputProps = { + onUpdateList: () => void + setHitResult: (res: HitTestingResponse) => void + setExternalHitResult: (res: ExternalKnowledgeBaseHitTestingResponse) => void + loading: boolean + queries: Query[] + setQueries: (v: Query[]) => void + isExternal?: boolean + onClickRetrievalMethod: () => void + retrievalConfig: RetrievalConfig + isEconomy: boolean + onSubmit?: () => void + hitTestingMutation: UseMutateAsyncFunction + externalKnowledgeBaseHitTestingMutation: UseMutateAsyncFunction< + ExternalKnowledgeBaseHitTestingResponse, + Error, + ExternalKnowledgeBaseHitTestingRequest, + unknown + > +} + +const QueryInput = ({ + onUpdateList, + setHitResult, + setExternalHitResult, + loading, + queries, + setQueries, + isExternal = false, + onClickRetrievalMethod, + retrievalConfig, + isEconomy, + onSubmit: _onSubmit, + hitTestingMutation, + externalKnowledgeBaseHitTestingMutation, +}: QueryInputProps) => { + const { t } = useTranslation() + const isMultimodal = useDatasetDetailContextWithSelector(s => !!s.dataset?.is_multimodal) + const [isSettingsOpen, setIsSettingsOpen] = useState(false) + const [externalRetrievalSettings, setExternalRetrievalSettings] = useState({ + top_k: 4, + score_threshold: 0.5, + score_threshold_enabled: false, + }) + + const text = useMemo(() => { + return queries.find(query => query.content_type === 'text_query')?.content ?? '' + }, [queries]) + + const images = useMemo(() => { + const imageQueries = queries + .filter(query => query.content_type === 'image_query') + .map(query => query.file_info) + .filter(Boolean) as Attachment[] + return imageQueries.map(item => ({ + id: uuid4(), + name: item.name, + size: item.size, + mimeType: item.mime_type, + extension: item.extension, + sourceUrl: item.source_url, + uploadedId: item.id, + progress: 100, + })) || [] + }, [queries]) + + const isAllUploaded = useMemo(() => { + return images.every(image => !!image.uploadedId) + }, [images]) + + const handleSaveExternalRetrievalSettings = useCallback((data: { + top_k: number + score_threshold: number + score_threshold_enabled: boolean + }) => { + setExternalRetrievalSettings(data) + setIsSettingsOpen(false) + }, []) + + const handleTextChange = useCallback((event: ChangeEvent) => { + const newQueries = [...queries] + const textQuery = newQueries.find(query => query.content_type === 'text_query') + if (!textQuery) { + newQueries.push({ + content: event.target.value, + content_type: 'text_query', + file_info: null, + }) + } + else { + textQuery.content = event.target.value + } + setQueries(newQueries) + }, [queries, setQueries]) + + const handleImageChange = useCallback((files: FileEntity[]) => { + let newQueries = [...queries] + newQueries = newQueries.filter(query => query.content_type !== 'image_query') + files.forEach((file) => { + newQueries.push({ + content: file.sourceUrl || '', + content_type: 'image_query', + file_info: { + id: file.uploadedId || '', + mime_type: file.mimeType, + source_url: file.sourceUrl || '', + name: file.name, + size: file.size, + extension: file.extension, + }, + }) + }) + setQueries(newQueries) + }, [queries, setQueries]) + + const onSubmit = useCallback(async () => { + await hitTestingMutation({ + query: text, + attachment_ids: images.map(image => image.uploadedId), + retrieval_model: { + ...retrievalConfig, + search_method: isEconomy ? RETRIEVE_METHOD.keywordSearch : retrievalConfig.search_method, + }, + }, { + onSuccess: (data) => { + setHitResult(data) + onUpdateList?.() + if (_onSubmit) + _onSubmit() + }, + }) + }, [text, retrievalConfig, isEconomy, hitTestingMutation, onUpdateList, _onSubmit, images, setHitResult]) + + const externalRetrievalTestingOnSubmit = useCallback(async () => { + await externalKnowledgeBaseHitTestingMutation({ + query: text, + external_retrieval_model: { + top_k: externalRetrievalSettings.top_k, + score_threshold: externalRetrievalSettings.score_threshold, + score_threshold_enabled: externalRetrievalSettings.score_threshold_enabled, + }, + }, { + onSuccess: (data) => { + setExternalHitResult(data) + onUpdateList?.() + }, + }) + }, [text, externalRetrievalSettings, externalKnowledgeBaseHitTestingMutation, onUpdateList, setExternalHitResult]) + + const retrievalMethod = isEconomy ? RETRIEVE_METHOD.keywordSearch : retrievalConfig.search_method + const icon = + const TextAreaComp = useMemo(() => { + return ( +