diff --git a/AGENTS.md b/AGENTS.md index deab7c8629..7d96ac3a6d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -25,6 +25,30 @@ pnpm type-check:tsgo pnpm test ``` +### Frontend Linting + +ESLint is used for frontend code quality. Available commands: + +```bash +# Lint all files (report only) +pnpm lint + +# Lint and auto-fix issues +pnpm lint:fix + +# Lint specific files or directories +pnpm lint:fix app/components/base/button/ +pnpm lint:fix app/components/base/button/index.tsx + +# Lint quietly (errors only, no warnings) +pnpm lint:quiet + +# Check code complexity +pnpm lint:complexity +``` + +**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files. + ## Testing & Quality Practices - Follow TDD: red → green → refactor. diff --git a/agent-notes/.gitkeep b/agent-notes/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/AGENTS.md b/api/AGENTS.md index 6ce419828b..13adb42276 100644 --- a/api/AGENTS.md +++ b/api/AGENTS.md @@ -1,97 +1,47 @@ # API Agent Guide -## Agent Notes (must-check) +## Notes for Agent (must-check) -Before you start work on any backend file under `api/`, you MUST check whether a related note exists under: +Before changing any backend code under `api/`, you MUST read the surrounding docstrings and comments. These notes contain required context (invariants, edge cases, trade-offs) and are treated as part of the spec. -- `agent-notes/.md` +Look for: -Rules: +- The module (file) docstring at the top of a source code file +- Docstrings on classes and functions/methods +- Paragraph/block comments for non-obvious logic -- **Path mapping**: for a target file `/.py`, the note must be `agent-notes//.py.md` (same folder structure, same filename, plus `.md`). -- **Before working**: - - If the note exists, read it first and follow any constraints/decisions recorded there. - - If the note conflicts with the current code, or references an "origin" file/path that has been deleted, renamed, or migrated, treat the **code as the single source of truth** and update the note to match reality. - - If the note does not exist, create it with a short architecture/intent summary and any relevant invariants/edge cases. -- **During working**: - - Keep the note in sync as you discover constraints, make decisions, or change approach. - - If you move/rename a file, migrate its note to the new mapped path (and fix any outdated references inside the note). - - Record non-obvious edge cases, trade-offs, and the test/verification plan as you go (not just at the end). - - Keep notes **coherent**: integrate new findings into the relevant sections and rewrite for clarity; avoid append-only “recent fix” / changelog-style additions unless the note is explicitly intended to be a changelog. -- **When finishing work**: - - Update the related note(s) to reflect what changed, why, and any new edge cases/tests. - - If a file is deleted, remove or clearly deprecate the corresponding note so it cannot be mistaken as current guidance. - - Keep notes concise and accurate; they are meant to prevent repeated rediscovery. +### What to write where -## Skill Index +- Keep notes scoped: module notes cover module-wide context, class notes cover class-wide context, function/method notes cover behavioural contracts, and paragraph/block comments cover local “why”. Avoid duplicating the same content across scopes unless repetition prevents misuse. +- **Module (file) docstring**: purpose, boundaries, key invariants, and “gotchas” that a new reader must know before editing. + - Include cross-links to the key collaborators (modules/services) when discovery is otherwise hard. + - Prefer stable facts (invariants, contracts) over ephemeral “today we…” notes. +- **Class docstring**: responsibility, lifecycle, invariants, and how it should be used (or not used). + - If the class is intentionally stateful, note what state exists and what methods mutate it. + - If concurrency/async assumptions matter, state them explicitly. +- **Function/method docstring**: behavioural contract. + - Document arguments, return shape, side effects (DB writes, external I/O, task dispatch), and raised domain exceptions. + - Add examples only when they prevent misuse. +- **Paragraph/block comments**: explain *why* (trade-offs, historical constraints, surprising edge cases), not what the code already states. + - Keep comments adjacent to the logic they justify; delete or rewrite comments that no longer match reality. -Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it. +### Rules (must follow) -### Platform Foundations +In this section, “notes” means module/class/function docstrings plus any relevant paragraph/block comments. -#### [Infrastructure Overview](agent_skills/infra.md) - -- **When to read this** - - You need to understand where a feature belongs in the architecture. - - You’re wiring storage, Redis, vector stores, or OTEL. - - You’re about to add CLI commands or async jobs. -- **What it covers** - - Configuration stack (`configs/app_config.py`, remote settings) - - Storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`) - - Redis conventions (`extensions/ext_redis.py`) - - Plugin runtime topology - - Vector-store factory (`core/rag/datasource/vdb/*`) - - Observability hooks - - SSRF proxy usage - - Core CLI commands - -### Plugin & Extension Development - -#### [Plugin Systems](agent_skills/plugin.md) - -- **When to read this** - - You’re building or debugging a marketplace plugin. - - You need to know how manifests, providers, daemons, and migrations fit together. -- **What it covers** - - Plugin manifests (`core/plugin/entities/plugin.py`) - - Installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands) - - Runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent) - - Daemon coordination (`core/plugin/entities/plugin_daemon.py`) - - How provider registries surface capabilities to the rest of the platform - -#### [Plugin OAuth](agent_skills/plugin_oauth.md) - -- **When to read this** - - You must integrate OAuth for a plugin or datasource. - - You’re handling credential encryption or refresh flows. -- **Topics** - - Credential storage - - Encryption helpers (`core/helper/provider_encryption.py`) - - OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`) - - How console/API layers expose the flows - -### Workflow Entry & Execution - -#### [Trigger Concepts](agent_skills/trigger.md) - -- **When to read this** - - You’re debugging why a workflow didn’t start. - - You’re adding a new trigger type or hook. - - You need to trace async execution, draft debugging, or webhook/schedule pipelines. -- **Details** - - Start-node taxonomy - - Webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`) - - Async orchestration (`services/async_workflow_service.py`, Celery queues) - - Debug event bus - - Storage/logging interactions - -## General Reminders - -- All skill docs assume you follow the coding style rules below—run the lint/type/test commands before submitting changes. -- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`). -- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules. -- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`. -- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently. +- **Before working** + - Read the notes in the area you’ll touch; treat them as part of the spec. + - If a docstring or comment conflicts with the current code, treat the **code as the single source of truth** and update the docstring or comment to match reality. + - If important intent/invariants/edge cases are missing, add them in the closest docstring or comment (module for overall scope, function for behaviour). +- **During working** + - Keep the notes in sync as you discover constraints, make decisions, or change approach. + - If you move/rename responsibilities across modules/classes, update the affected docstrings and comments so readers can still find the “why” and the invariants. + - Record non-obvious edge cases, trade-offs, and the test/verification plan in the nearest docstring or comment that will stay correct. + - Keep the notes **coherent**: integrate new findings into the relevant docstrings and comments; avoid append-only “recent fix” / changelog-style additions. +- **When finishing** + - Update the notes to reflect what changed, why, and any new edge cases/tests. + - Remove or rewrite any comments that could be mistaken as current guidance but no longer apply. + - Keep docstrings and comments concise and accurate; they are meant to prevent repeated rediscovery. ## Coding Style @@ -226,7 +176,7 @@ Before opening a PR / submitting: - Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic. - Services: coordinate repositories, providers, background tasks; keep side effects explicit. -- Document non-obvious behaviour with concise comments. +- Document non-obvious behaviour with concise docstrings and comments. ### Miscellaneous diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index cd958bbb36..d05e726dcb 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -36,6 +36,16 @@ class NotionEstimatePayload(BaseModel): doc_language: str = Field(default="English") +class DataSourceNotionListQuery(BaseModel): + dataset_id: str | None = Field(default=None, description="Dataset ID") + credential_id: str = Field(..., description="Credential ID", min_length=1) + datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string") + + +class DataSourceNotionPreviewQuery(BaseModel): + credential_id: str = Field(..., description="Credential ID", min_length=1) + + register_schema_model(console_ns, NotionEstimatePayload) @@ -136,26 +146,15 @@ class DataSourceNotionListApi(Resource): def get(self): current_user, current_tenant_id = current_account_with_tenant() - dataset_id = request.args.get("dataset_id", default=None, type=str) - credential_id = request.args.get("credential_id", default=None, type=str) - if not credential_id: - raise ValueError("Credential id is required.") + query = DataSourceNotionListQuery.model_validate(request.args.to_dict()) # Get datasource_parameters from query string (optional, for GitHub and other datasources) - datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str) - datasource_parameters = {} - if datasource_parameters_str: - try: - datasource_parameters = json.loads(datasource_parameters_str) - if not isinstance(datasource_parameters, dict): - raise ValueError("datasource_parameters must be a JSON object.") - except json.JSONDecodeError: - raise ValueError("Invalid datasource_parameters JSON format.") + datasource_parameters = query.datasource_parameters or {} datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( tenant_id=current_tenant_id, - credential_id=credential_id, + credential_id=query.credential_id, provider="notion_datasource", plugin_id="langgenius/notion_datasource", ) @@ -164,8 +163,8 @@ class DataSourceNotionListApi(Resource): exist_page_ids = [] with Session(db.engine) as session: # import notion in the exist dataset - if dataset_id: - dataset = DatasetService.get_dataset(dataset_id) + if query.dataset_id: + dataset = DatasetService.get_dataset(query.dataset_id) if not dataset: raise NotFound("Dataset not found.") if dataset.data_source_type != "notion_import": @@ -173,7 +172,7 @@ class DataSourceNotionListApi(Resource): documents = session.scalars( select(Document).filter_by( - dataset_id=dataset_id, + dataset_id=query.dataset_id, tenant_id=current_tenant_id, data_source_type="notion_import", enabled=True, @@ -240,13 +239,12 @@ class DataSourceNotionApi(Resource): def get(self, page_id, page_type): _, current_tenant_id = current_account_with_tenant() - credential_id = request.args.get("credential_id", default=None, type=str) - if not credential_id: - raise ValueError("Credential id is required.") + query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict()) + datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( tenant_id=current_tenant_id, - credential_id=credential_id, + credential_id=query.credential_id, provider="notion_datasource", plugin_id="langgenius/notion_datasource", ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 8ceb896d4f..e9371b608c 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -176,7 +176,18 @@ class IndexingEstimatePayload(BaseModel): return result -register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload) +class ConsoleDatasetListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + include_all: bool = Field(default=False, description="Include all datasets") + ids: list[str] = Field(default_factory=list, description="Filter by dataset IDs") + tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs") + + +register_schema_models( + console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery +) def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]: @@ -275,18 +286,19 @@ class DatasetListApi(Resource): @enterprise_license_required def get(self): current_user, current_tenant_id = current_account_with_tenant() - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - ids = request.args.getlist("ids") + query = ConsoleDatasetListQuery.model_validate(request.args.to_dict(flat=False)) # provider = request.args.get("provider", default="vendor") - search = request.args.get("keyword", default=None, type=str) - tag_ids = request.args.getlist("tag_ids") - include_all = request.args.get("include_all", default="false").lower() == "true" - if ids: - datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id) + if query.ids: + datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id) else: datasets, total = DatasetService.get_datasets( - page, limit, current_tenant_id, current_user, search, tag_ids, include_all + query.page, + query.limit, + current_tenant_id, + current_user, + query.keyword, + query.tag_ids, + query.include_all, ) # check embedding setting @@ -318,7 +330,13 @@ class DatasetListApi(Resource): else: item.update({"partial_member_list": []}) - response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} + response = { + "data": data, + "has_more": len(datasets) == query.limit, + "limit": query.limit, + "total": total, + "page": query.page, + } return response, 200 @console_ns.doc("create_dataset") diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index a70a7ce480..588eb6e1b8 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -98,12 +98,19 @@ class BedrockRetrievalPayload(BaseModel): knowledge_id: str +class ExternalApiTemplateListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + + register_schema_models( console_ns, ExternalKnowledgeApiPayload, ExternalDatasetCreatePayload, ExternalHitTestingPayload, BedrockRetrievalPayload, + ExternalApiTemplateListQuery, ) @@ -124,19 +131,17 @@ class ExternalApiTemplateListApi(Resource): @account_initialization_required def get(self): _, current_tenant_id = current_account_with_tenant() - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - search = request.args.get("keyword", default=None, type=str) + query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict()) external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis( - page, limit, current_tenant_id, search + query.page, query.limit, current_tenant_id, query.keyword ) response = { "data": [item.to_dict() for item in external_knowledge_apis], - "has_more": len(external_knowledge_apis) == limit, - "limit": limit, + "has_more": len(external_knowledge_apis) == query.limit, + "limit": query.limit, "total": total, - "page": page, + "page": query.page, } return response, 200 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index e42db10ba6..b77eac605e 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -3,7 +3,7 @@ from typing import Any from flask import request from flask_restx import Resource, marshal_with -from pydantic import BaseModel +from pydantic import BaseModel, Field from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -28,6 +28,10 @@ class InstalledAppUpdatePayload(BaseModel): is_pinned: bool | None = None +class InstalledAppsListQuery(BaseModel): + app_id: str | None = Field(default=None, description="App ID to filter by") + + logger = logging.getLogger(__name__) @@ -37,13 +41,13 @@ class InstalledAppsListApi(Resource): @account_initialization_required @marshal_with(installed_app_list_fields) def get(self): - app_id = request.args.get("app_id", default=None, type=str) + query = InstalledAppsListQuery.model_validate(request.args.to_dict()) current_user, current_tenant_id = current_account_with_tenant() - if app_id: + if query.app_id: installed_apps = db.session.scalars( select(InstalledApp).where( - and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id) + and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == query.app_id) ) ).all() else: diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index ed22ef045d..e1ea007232 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,20 +1,19 @@ +from typing import Literal + from flask import request -from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator from configs import dify_config +from controllers.fastopenapi import console_router from libs.helper import EmailStr, extract_remote_ip from libs.password import valid_password from models.model import DifySetup, db from services.account_service import RegisterService, TenantService -from . import console_ns from .error import AlreadySetupError, NotInitValidateError from .init_validate import get_init_validate_status from .wraps import only_edition_self_hosted -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class SetupRequestPayload(BaseModel): email: EmailStr = Field(..., description="Admin email address") @@ -28,78 +27,66 @@ class SetupRequestPayload(BaseModel): return valid_password(value) -console_ns.schema_model( - SetupRequestPayload.__name__, - SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +class SetupStatusResponse(BaseModel): + step: Literal["not_started", "finished"] = Field(description="Setup step status") + setup_at: str | None = Field(default=None, description="Setup completion time (ISO format)") + + +class SetupResponse(BaseModel): + result: str = Field(description="Setup result", examples=["success"]) + + +@console_router.get( + "/setup", + response_model=SetupStatusResponse, + tags=["console"], ) +def get_setup_status_api() -> SetupStatusResponse: + """Get system setup status.""" + if dify_config.EDITION == "SELF_HOSTED": + setup_status = get_setup_status() + if setup_status and not isinstance(setup_status, bool): + return SetupStatusResponse(step="finished", setup_at=setup_status.setup_at.isoformat()) + if setup_status: + return SetupStatusResponse(step="finished") + return SetupStatusResponse(step="not_started") + return SetupStatusResponse(step="finished") -@console_ns.route("/setup") -class SetupApi(Resource): - @console_ns.doc("get_setup_status") - @console_ns.doc(description="Get system setup status") - @console_ns.response( - 200, - "Success", - console_ns.model( - "SetupStatusResponse", - { - "step": fields.String(description="Setup step status", enum=["not_started", "finished"]), - "setup_at": fields.String(description="Setup completion time (ISO format)", required=False), - }, - ), +@console_router.post( + "/setup", + response_model=SetupResponse, + tags=["console"], + status_code=201, +) +@only_edition_self_hosted +def setup_system(payload: SetupRequestPayload) -> SetupResponse: + """Initialize system setup with admin account.""" + if get_setup_status(): + raise AlreadySetupError() + + tenant_count = TenantService.get_tenant_count() + if tenant_count > 0: + raise AlreadySetupError() + + if not get_init_validate_status(): + raise NotInitValidateError() + + normalized_email = payload.email.lower() + + RegisterService.setup( + email=normalized_email, + name=payload.name, + password=payload.password, + ip_address=extract_remote_ip(request), + language=payload.language, ) - def get(self): - """Get system setup status""" - if dify_config.EDITION == "SELF_HOSTED": - setup_status = get_setup_status() - # Check if setup_status is a DifySetup object rather than a bool - if setup_status and not isinstance(setup_status, bool): - return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} - elif setup_status: - return {"step": "finished"} - return {"step": "not_started"} - return {"step": "finished"} - @console_ns.doc("setup_system") - @console_ns.doc(description="Initialize system setup with admin account") - @console_ns.expect(console_ns.models[SetupRequestPayload.__name__]) - @console_ns.response( - 201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")}) - ) - @console_ns.response(400, "Already setup or validation failed") - @only_edition_self_hosted - def post(self): - """Initialize system setup with admin account""" - # is set up - if get_setup_status(): - raise AlreadySetupError() - - # is tenant created - tenant_count = TenantService.get_tenant_count() - if tenant_count > 0: - raise AlreadySetupError() - - if not get_init_validate_status(): - raise NotInitValidateError() - - args = SetupRequestPayload.model_validate(console_ns.payload) - normalized_email = args.email.lower() - - # setup - RegisterService.setup( - email=normalized_email, - name=args.name, - password=args.password, - ip_address=extract_remote_ip(request), - language=args.language, - ) - - return {"result": "success"}, 201 + return SetupResponse(result="success") -def get_setup_status(): +def get_setup_status() -> DifySetup | bool | None: if dify_config.EDITION == "SELF_HOSTED": return db.session.query(DifySetup).first() - else: - return True + + return True diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 023ffc991a..9988524a80 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -40,6 +40,7 @@ register_schema_models( TagBasePayload, TagBindingPayload, TagBindingRemovePayload, + TagListQueryParam, ) diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 419261ba2a..fdb23acf52 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -1,15 +1,11 @@ -import json import logging import httpx -from flask import request -from flask_restx import Resource, fields from packaging import version from pydantic import BaseModel, Field from configs import dify_config - -from . import console_ns +from controllers.fastopenapi import console_router logger = logging.getLogger(__name__) @@ -18,69 +14,61 @@ class VersionQuery(BaseModel): current_version: str = Field(..., description="Current application version") -console_ns.schema_model( - VersionQuery.__name__, - VersionQuery.model_json_schema(ref_template="#/definitions/{model}"), +class VersionFeatures(BaseModel): + can_replace_logo: bool = Field(description="Whether logo replacement is supported") + model_load_balancing_enabled: bool = Field(description="Whether model load balancing is enabled") + + +class VersionResponse(BaseModel): + version: str = Field(description="Latest version number") + release_date: str = Field(description="Release date of latest version") + release_notes: str = Field(description="Release notes for latest version") + can_auto_update: bool = Field(description="Whether auto-update is supported") + features: VersionFeatures = Field(description="Feature flags and capabilities") + + +@console_router.get( + "/version", + response_model=VersionResponse, + tags=["console"], ) +def check_version_update(query: VersionQuery) -> VersionResponse: + """Check for application version updates.""" + check_update_url = dify_config.CHECK_UPDATE_URL - -@console_ns.route("/version") -class VersionApi(Resource): - @console_ns.doc("check_version_update") - @console_ns.doc(description="Check for application version updates") - @console_ns.expect(console_ns.models[VersionQuery.__name__]) - @console_ns.response( - 200, - "Success", - console_ns.model( - "VersionResponse", - { - "version": fields.String(description="Latest version number"), - "release_date": fields.String(description="Release date of latest version"), - "release_notes": fields.String(description="Release notes for latest version"), - "can_auto_update": fields.Boolean(description="Whether auto-update is supported"), - "features": fields.Raw(description="Feature flags and capabilities"), - }, + result = VersionResponse( + version=dify_config.project.version, + release_date="", + release_notes="", + can_auto_update=False, + features=VersionFeatures( + can_replace_logo=dify_config.CAN_REPLACE_LOGO, + model_load_balancing_enabled=dify_config.MODEL_LB_ENABLED, ), ) - def get(self): - """Check for application version updates""" - args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - check_update_url = dify_config.CHECK_UPDATE_URL - result = { - "version": dify_config.project.version, - "release_date": "", - "release_notes": "", - "can_auto_update": False, - "features": { - "can_replace_logo": dify_config.CAN_REPLACE_LOGO, - "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED, - }, - } - - if not check_update_url: - return result - - try: - response = httpx.get( - check_update_url, - params={"current_version": args.current_version}, - timeout=httpx.Timeout(timeout=10.0, connect=3.0), - ) - except Exception as error: - logger.warning("Check update version error: %s.", str(error)) - result["version"] = args.current_version - return result - - content = json.loads(response.content) - if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"): - result["version"] = content["version"] - result["release_date"] = content["releaseDate"] - result["release_notes"] = content["releaseNotes"] - result["can_auto_update"] = content["canAutoUpdate"] + if not check_update_url: return result + try: + response = httpx.get( + check_update_url, + params={"current_version": query.current_version}, + timeout=httpx.Timeout(timeout=10.0, connect=3.0), + ) + content = response.json() + except Exception as error: + logger.warning("Check update version error: %s.", str(error)) + result.version = query.current_version + return result + latest_version = content.get("version", result.version) + if _has_new_version(latest_version=latest_version, current_version=f"{query.current_version}"): + result.version = latest_version + result.release_date = content.get("releaseDate", "") + result.release_notes = content.get("releaseNotes", "") + result.can_auto_update = content.get("canAutoUpdate", False) + return result + def _has_new_version(*, latest_version: str, current_version: str) -> bool: try: diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 94faf8dd42..b036a71f18 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -87,6 +87,14 @@ class TagUnbindingPayload(BaseModel): target_id: str +class DatasetListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + include_all: bool = Field(default=False, description="Include all datasets") + tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs") + + register_schema_models( service_api_ns, DatasetCreatePayload, @@ -96,6 +104,7 @@ register_schema_models( TagDeletePayload, TagBindingPayload, TagUnbindingPayload, + DatasetListQuery, ) @@ -113,15 +122,11 @@ class DatasetListApi(DatasetApiResource): ) def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) + query = DatasetListQuery.model_validate(request.args.to_dict(flat=False)) # provider = request.args.get("provider", default="vendor") - search = request.args.get("keyword", default=None, type=str) - tag_ids = request.args.getlist("tag_ids") - include_all = request.args.get("include_all", default="false").lower() == "true" datasets, total = DatasetService.get_datasets( - page, limit, tenant_id, current_user, search, tag_ids, include_all + query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all ) # check embedding setting provider_manager = ProviderManager() @@ -147,7 +152,13 @@ class DatasetListApi(DatasetApiResource): item["embedding_available"] = False else: item["embedding_available"] = True - response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} + response = { + "data": data, + "has_more": len(datasets) == query.limit, + "limit": query.limit, + "total": total, + "page": query.page, + } return response, 200 @service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__]) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 49ff4f57dc..1260645624 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -69,7 +69,14 @@ class DocumentTextUpdate(BaseModel): return self -for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]: +class DocumentListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + status: str | None = Field(default=None, description="Document status filter") + + +for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate, DocumentListQuery]: service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore @@ -460,34 +467,33 @@ class DocumentListApi(DatasetApiResource): def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - search = request.args.get("keyword", default=None, type=str) - status = request.args.get("status", default=None, type=str) + query_params = DocumentListQuery.model_validate(request.args.to_dict()) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) - if status: - query = DocumentService.apply_display_status_filter(query, status) + if query_params.status: + query = DocumentService.apply_display_status_filter(query, query_params.status) - if search: - search = f"%{search}%" + if query_params.keyword: + search = f"%{query_params.keyword}%" query = query.where(Document.name.like(search)) query = query.order_by(desc(Document.created_at), desc(Document.position)) - paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = db.paginate( + select=query, page=query_params.page, per_page=query_params.limit, max_per_page=100, error_out=False + ) documents = paginated_documents.items response = { "data": marshal(documents, document_fields), - "has_more": len(documents) == limit, - "limit": limit, + "has_more": len(documents) == query_params.limit, + "limit": query_params.limit, "total": paginated_documents.total, - "page": page, + "page": query_params.page, } return response diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index aab25c1af3..b8d9508004 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -11,7 +11,9 @@ from controllers.service_api.wraps import DatasetApiResource, cloud_edition_bill from fields.dataset_fields import dataset_metadata_fields from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, MetadataArgs, + MetadataDetail, MetadataOperationData, ) from services.metadata_service import MetadataService @@ -22,7 +24,13 @@ class MetadataUpdatePayload(BaseModel): register_schema_model(service_api_ns, MetadataUpdatePayload) -register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData) +register_schema_models( + service_api_ns, + MetadataArgs, + MetadataDetail, + DocumentMetadataOperation, + MetadataOperationData, +) @service_api_ns.route("/datasets//metadata") diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 2760466a3b..8b6b8f227b 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner): queue_manager=queue_manager, stream=application_generate_entity.stream, agent=True, + message_id=message.id, + user_id=application_generate_entity.user_id, + tenant_id=app_config.tenant_id, ) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index e2e6c11480..617515945b 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,6 +1,8 @@ +import base64 import logging import time from collections.abc import Generator, Mapping, Sequence +from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity @@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ModelConfigWithCredentialsEntity, ) -from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, +) from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch +from core.file.enums import FileTransferMethod, FileType from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, + TextPromptMessageContent, ) from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError @@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform -from models.model import App, AppMode, Message, MessageAnnotation +from core.tools.tool_file_manager import ToolFileManager +from extensions.ext_database import db +from models.enums import CreatorUserRole +from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: from core.file.models import File @@ -203,6 +215,9 @@ class AppRunner: queue_manager: AppQueueManager, stream: bool, agent: bool = False, + message_id: str | None = None, + user_id: str | None = None, + tenant_id: str | None = None, ): """ Handle invoke result @@ -210,21 +225,41 @@ class AppRunner: :param queue_manager: application queue manager :param stream: stream :param agent: agent + :param message_id: message id for multimodal output + :param user_id: user id for multimodal output + :param tenant_id: tenant id for multimodal output :return: """ if not stream and isinstance(invoke_result, LLMResult): - self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) + self._handle_invoke_result_direct( + invoke_result=invoke_result, + queue_manager=queue_manager, + ) elif stream and isinstance(invoke_result, Generator): - self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) + self._handle_invoke_result_stream( + invoke_result=invoke_result, + queue_manager=queue_manager, + agent=agent, + message_id=message_id, + user_id=user_id, + tenant_id=tenant_id, + ) else: raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") - def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool): + def _handle_invoke_result_direct( + self, + invoke_result: LLMResult, + queue_manager: AppQueueManager, + ): """ Handle invoke result direct :param invoke_result: invoke result :param queue_manager: application queue manager :param agent: agent + :param message_id: message id for multimodal output + :param user_id: user id for multimodal output + :param tenant_id: tenant id for multimodal output :return: """ queue_manager.publish( @@ -235,13 +270,22 @@ class AppRunner: ) def _handle_invoke_result_stream( - self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool + self, + invoke_result: Generator[LLMResultChunk, None, None], + queue_manager: AppQueueManager, + agent: bool, + message_id: str | None = None, + user_id: str | None = None, + tenant_id: str | None = None, ): """ Handle invoke result :param invoke_result: invoke result :param queue_manager: application queue manager :param agent: agent + :param message_id: message id for multimodal output + :param user_id: user id for multimodal output + :param tenant_id: tenant id for multimodal output :return: """ model: str = "" @@ -259,12 +303,26 @@ class AppRunner: text += message.content elif isinstance(message.content, list): for content in message.content: - if not isinstance(content, str): - # TODO(QuantumGhost): Add multimodal output support for easy ui. - _logger.warning("received multimodal output, type=%s", type(content)) + if isinstance(content, str): + text += content + elif isinstance(content, TextPromptMessageContent): text += content.data + elif isinstance(content, ImagePromptMessageContent): + if message_id and user_id and tenant_id: + try: + self._handle_multimodal_image_content( + content=content, + message_id=message_id, + user_id=user_id, + tenant_id=tenant_id, + queue_manager=queue_manager, + ) + except Exception: + _logger.exception("Failed to handle multimodal image output") + else: + _logger.warning("Received multimodal output but missing required parameters") else: - text += content # failback to str + text += content.data if hasattr(content, "data") else str(content) if not model: model = result.model @@ -289,6 +347,101 @@ class AppRunner: PublishFrom.APPLICATION_MANAGER, ) + def _handle_multimodal_image_content( + self, + content: ImagePromptMessageContent, + message_id: str, + user_id: str, + tenant_id: str, + queue_manager: AppQueueManager, + ): + """ + Handle multimodal image content from LLM response. + Save the image and create a MessageFile record. + + :param content: ImagePromptMessageContent instance + :param message_id: message id + :param user_id: user id + :param tenant_id: tenant id + :param queue_manager: queue manager + :return: + """ + _logger.info("Handling multimodal image content for message %s", message_id) + + image_url = content.url + base64_data = content.base64_data + + _logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data) + + if not image_url and not base64_data: + _logger.warning("Image content has neither URL nor base64 data") + return + + tool_file_manager = ToolFileManager() + + # Save the image file + try: + if image_url: + # Download image from URL + _logger.info("Downloading image from URL: %s", image_url) + tool_file = tool_file_manager.create_file_by_url( + user_id=user_id, + tenant_id=tenant_id, + file_url=image_url, + conversation_id=None, + ) + _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id) + elif base64_data: + if base64_data.startswith("data:"): + base64_data = base64_data.split(",", 1)[1] + + image_binary = base64.b64decode(base64_data) + mimetype = content.mime_type or "image/png" + extension = guess_extension(mimetype) or ".png" + + tool_file = tool_file_manager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=image_binary, + mimetype=mimetype, + filename=f"generated_image{extension}", + ) + _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id) + else: + return + except Exception: + _logger.exception("Failed to save image file") + return + + # Create MessageFile record + message_file = MessageFile( + message_id=message_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + belongs_to="assistant", + url=f"/files/tools/{tool_file.id}", + upload_file_id=tool_file.id, + created_by_role=( + CreatorUserRole.ACCOUNT + if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} + else CreatorUserRole.END_USER + ), + created_by=user_id, + ) + + db.session.add(message_file) + db.session.commit() + db.session.refresh(message_file) + + # Publish QueueMessageFileEvent + queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file.id), + PublishFrom.APPLICATION_MANAGER, + ) + + _logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id) + def moderation_for_inputs( self, *, diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index f8338b226b..7d1a4c619f 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream + invoke_result=invoke_result, + queue_manager=queue_manager, + stream=application_generate_entity.stream, + message_id=message.id, + user_id=application_generate_entity.user_id, + tenant_id=app_config.tenant_id, ) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index ddfb5725b4..a872c2e1f7 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream + invoke_result=invoke_result, + queue_manager=queue_manager, + stream=application_generate_entity.stream, + message_id=message.id, + user_id=application_generate_entity.user_id, + tenant_id=app_config.tenant_id, ) diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 5bb93fa44a..6c997753fa 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -39,6 +39,7 @@ from core.app.entities.task_entities import ( MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, + StreamEvent, StreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline @@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): _task_state: EasyUITaskState _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] + _precomputed_event_type: StreamEvent | None = None def __init__( self, @@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id) + # Determine the event type once, on first LLM chunk, and reuse for subsequent chunks + if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None: + self._precomputed_event_type = self._message_cycle_manager.get_message_event_type( + message_id=self._message_id + ) yield self._message_cycle_manager.message_to_stream_response( answer=cast(str, delta_text), message_id=self._message_id, - event_type=event_type, + event_type=self._precomputed_event_type, ) else: yield self._agent_message_to_stream_response( diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 0e7f300cee..2d4ee08daf 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -5,7 +5,7 @@ from threading import Thread from typing import Union from flask import Flask, current_app -from sqlalchemy import exists, select +from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config @@ -30,6 +30,7 @@ from core.app.entities.task_entities import ( StreamEvent, WorkflowTaskState, ) +from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from core.tools.signature import sign_tool_file from extensions.ext_database import db @@ -57,13 +58,15 @@ class MessageCycleManager: self._message_has_file: set[str] = set() def get_message_event_type(self, message_id: str) -> StreamEvent: + # Fast path: cached determination from prior QueueMessageFileEvent if message_id in self._message_has_file: return StreamEvent.MESSAGE_FILE - with Session(db.engine, expire_on_commit=False) as session: - has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar() + # Use SQLAlchemy 2.x style session.scalar(select(...)) + with session_factory.create_session() as session: + message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id)) - if has_file: + if message_file: self._message_has_file.add(message_id) return StreamEvent.MESSAGE_FILE @@ -199,6 +202,8 @@ class MessageCycleManager: message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id)) if message_file and message_file.url is not None: + self._message_has_file.add(message_file.message_id) + # get tool file id tool_file_id = message_file.url.split("/")[-1] # trim extension diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index 98ea15e3fc..ce23da1e09 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,4 +1,4 @@ -from collections.abc import Generator, Mapping +from collections.abc import Generator from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin @@ -34,7 +34,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): def get_online_document_pages( self, user_id: str, - datasource_parameters: Mapping[str, Any], + datasource_parameters: dict[str, Any], provider_type: str, ) -> Generator[OnlineDocumentPagesMessage, None, None]: manager = PluginDatasourceManager() diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index c0f4c504d9..7a0757f219 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -1,7 +1,7 @@ import logging import time import uuid -from collections.abc import Generator, Sequence +from collections.abc import Callable, Generator, Iterator, Sequence from typing import Union from pydantic import ConfigDict @@ -30,6 +30,142 @@ def _gen_tool_call_id() -> str: return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" +def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None: + if not callbacks: + return + + for callback in callbacks: + try: + invoke(callback) + except Exception as e: + if callback.raise_error: + raise + logger.warning("Callback %s %s failed with error %s", callback.__class__.__name__, event, e) + + +def _get_or_create_tool_call( + existing_tools_calls: list[AssistantPromptMessage.ToolCall], + tool_call_id: str, +) -> AssistantPromptMessage.ToolCall: + """ + Get or create a tool call by ID. + + If `tool_call_id` is empty, returns the most recently created tool call. + """ + if not tool_call_id: + if not existing_tools_calls: + raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta") + return existing_tools_calls[-1] + + tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None) + if tool_call is None: + tool_call = AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), + ) + existing_tools_calls.append(tool_call) + + return tool_call + + +def _merge_tool_call_delta( + tool_call: AssistantPromptMessage.ToolCall, + delta: AssistantPromptMessage.ToolCall, +) -> None: + if delta.id: + tool_call.id = delta.id + if delta.type: + tool_call.type = delta.type + if delta.function.name: + tool_call.function.name = delta.function.name + if delta.function.arguments: + tool_call.function.arguments += delta.function.arguments + + +def _build_llm_result_from_first_chunk( + model: str, + prompt_messages: Sequence[PromptMessage], + chunks: Iterator[LLMResultChunk], +) -> LLMResult: + """ + Build a single `LLMResult` from the first returned chunk. + + This is used for `stream=False` because the plugin side may still implement the response via a chunked stream. + """ + content = "" + content_list: list[PromptMessageContentUnionTypes] = [] + usage = LLMUsage.empty_usage() + system_fingerprint: str | None = None + tools_calls: list[AssistantPromptMessage.ToolCall] = [] + + first_chunk = next(chunks, None) + if first_chunk is not None: + if isinstance(first_chunk.delta.message.content, str): + content += first_chunk.delta.message.content + elif isinstance(first_chunk.delta.message.content, list): + content_list.extend(first_chunk.delta.message.content) + + if first_chunk.delta.message.tool_calls: + _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls) + + usage = first_chunk.delta.usage or LLMUsage.empty_usage() + system_fingerprint = first_chunk.system_fingerprint + + return LLMResult( + model=model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage( + content=content or content_list, + tool_calls=tools_calls, + ), + usage=usage, + system_fingerprint=system_fingerprint, + ) + + +def _invoke_llm_via_plugin( + *, + tenant_id: str, + user_id: str, + plugin_id: str, + provider: str, + model: str, + credentials: dict, + model_parameters: dict, + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, +) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: + from core.plugin.impl.model import PluginModelClient + + plugin_model_manager = PluginModelClient() + return plugin_model_manager.invoke_llm( + tenant_id=tenant_id, + user_id=user_id, + plugin_id=plugin_id, + provider=provider, + model=model, + credentials=credentials, + model_parameters=model_parameters, + prompt_messages=list(prompt_messages), + tools=tools, + stop=list(stop) if stop else None, + stream=stream, + ) + + +def _normalize_non_stream_plugin_result( + model: str, + prompt_messages: Sequence[PromptMessage], + result: Union[LLMResult, Iterator[LLMResultChunk]], +) -> LLMResult: + if isinstance(result, LLMResult): + return result + return _build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result) + + def _increase_tool_call( new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall] ): @@ -40,42 +176,13 @@ def _increase_tool_call( :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE. """ - def get_tool_call(tool_call_id: str): - """ - Get or create a tool call by ID - - :param tool_call_id: tool call ID - :return: existing or new tool call - """ - if not tool_call_id: - return existing_tools_calls[-1] - - _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None) - if _tool_call is None: - _tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), - ) - existing_tools_calls.append(_tool_call) - - return _tool_call - for new_tool_call in new_tool_calls: # generate ID for tool calls with function name but no ID to track them if new_tool_call.function.name and not new_tool_call.id: new_tool_call.id = _gen_tool_call_id() - # get tool call - tool_call = get_tool_call(new_tool_call.id) - # update tool call - if new_tool_call.id: - tool_call.id = new_tool_call.id - if new_tool_call.type: - tool_call.type = new_tool_call.type - if new_tool_call.function.name: - tool_call.function.name = new_tool_call.function.name - if new_tool_call.function.arguments: - tool_call.function.arguments += new_tool_call.function.arguments + + tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id) + _merge_tool_call_delta(tool_call, new_tool_call) class LargeLanguageModel(AIModel): @@ -141,10 +248,7 @@ class LargeLanguageModel(AIModel): result: Union[LLMResult, Generator[LLMResultChunk, None, None]] try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - result = plugin_model_manager.invoke_llm( + result = _invoke_llm_via_plugin( tenant_id=self.tenant_id, user_id=user or "unknown", plugin_id=self.plugin_id, @@ -154,38 +258,13 @@ class LargeLanguageModel(AIModel): model_parameters=model_parameters, prompt_messages=prompt_messages, tools=tools, - stop=list(stop) if stop else None, + stop=stop, stream=stream, ) if not stream: - content = "" - content_list = [] - usage = LLMUsage.empty_usage() - system_fingerprint = None - tools_calls: list[AssistantPromptMessage.ToolCall] = [] - - for chunk in result: - if isinstance(chunk.delta.message.content, str): - content += chunk.delta.message.content - elif isinstance(chunk.delta.message.content, list): - content_list.extend(chunk.delta.message.content) - if chunk.delta.message.tool_calls: - _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) - - usage = chunk.delta.usage or LLMUsage.empty_usage() - system_fingerprint = chunk.system_fingerprint - break - - result = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=content or content_list, - tool_calls=tools_calls, - ), - usage=usage, - system_fingerprint=system_fingerprint, + result = _normalize_non_stream_plugin_result( + model=model, prompt_messages=prompt_messages, result=result ) except Exception as e: self._trigger_invoke_error_callbacks( @@ -425,27 +504,21 @@ class LargeLanguageModel(AIModel): :param user: unique user id :param callbacks: callbacks """ - if callbacks: - for callback in callbacks: - try: - callback.on_before_invoke( - llm_instance=self, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - logger.warning( - "Callback %s on_before_invoke failed with error %s", callback.__class__.__name__, e - ) + _run_callbacks( + callbacks, + event="on_before_invoke", + invoke=lambda callback: callback.on_before_invoke( + llm_instance=self, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) def _trigger_new_chunk_callbacks( self, @@ -473,26 +546,22 @@ class LargeLanguageModel(AIModel): :param stream: is stream response :param user: unique user id """ - if callbacks: - for callback in callbacks: - try: - callback.on_new_chunk( - llm_instance=self, - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - logger.warning("Callback %s on_new_chunk failed with error %s", callback.__class__.__name__, e) + _run_callbacks( + callbacks, + event="on_new_chunk", + invoke=lambda callback: callback.on_new_chunk( + llm_instance=self, + chunk=chunk, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) def _trigger_after_invoke_callbacks( self, @@ -521,28 +590,22 @@ class LargeLanguageModel(AIModel): :param user: unique user id :param callbacks: callbacks """ - if callbacks: - for callback in callbacks: - try: - callback.on_after_invoke( - llm_instance=self, - result=result, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - logger.warning( - "Callback %s on_after_invoke failed with error %s", callback.__class__.__name__, e - ) + _run_callbacks( + callbacks, + event="on_after_invoke", + invoke=lambda callback: callback.on_after_invoke( + llm_instance=self, + result=result, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) def _trigger_invoke_error_callbacks( self, @@ -571,25 +634,19 @@ class LargeLanguageModel(AIModel): :param user: unique user id :param callbacks: callbacks """ - if callbacks: - for callback in callbacks: - try: - callback.on_invoke_error( - llm_instance=self, - ex=ex, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - except Exception as e: - if callback.raise_error: - raise e - else: - logger.warning( - "Callback %s on_invoke_error failed with error %s", callback.__class__.__name__, e - ) + _run_callbacks( + callbacks, + event="on_invoke_error", + invoke=lambda callback: callback.on_invoke_error( + llm_instance=self, + ex=ex, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), + ) diff --git a/api/enums/hosted_provider.py b/api/enums/hosted_provider.py new file mode 100644 index 0000000000..c6d3715dc1 --- /dev/null +++ b/api/enums/hosted_provider.py @@ -0,0 +1,21 @@ +from enum import StrEnum + + +class HostedTrialProvider(StrEnum): + """ + Enum representing hosted model provider names for trial access. + """ + + OPENAI = "langgenius/openai/openai" + ANTHROPIC = "langgenius/anthropic/anthropic" + GEMINI = "langgenius/gemini/google" + X = "langgenius/x/x" + DEEPSEEK = "langgenius/deepseek/deepseek" + TONGYI = "langgenius/tongyi/tongyi" + + @property + def config_key(self) -> str: + """Return the config key used in dify_config (e.g., HOSTED_{config_key}_PAID_ENABLED).""" + if self == HostedTrialProvider.X: + return "XAI" + return self.name diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py index 0ef1513e11..5f98aa7b67 100644 --- a/api/extensions/ext_fastopenapi.py +++ b/api/extensions/ext_fastopenapi.py @@ -28,8 +28,10 @@ def init_app(app: DifyApp) -> None: # Ensure route decorators are evaluated. import controllers.console.ping as ping_module + from controllers.console import setup _ = ping_module + _ = setup router.include_router(console_router, prefix="/console/api") CORS( diff --git a/api/pyproject.toml b/api/pyproject.toml index 9f9bd11fa6..575c1434c5 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -64,7 +64,7 @@ dependencies = [ "pandas[excel,output-formatting,performance]~=2.2.2", "psycogreen~=1.0.2", "psycopg2-binary~=2.9.6", - "pycryptodome==3.19.1", + "pycryptodome==3.23.0", "pydantic~=2.11.4", "pydantic-extra-types~=2.10.3", "pydantic-settings~=2.11.0", diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index da22464d39..edcb2a7870 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -781,15 +781,16 @@ class AppDslService: return dependencies @classmethod - def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]: + def get_leaked_dependencies( + cls, tenant_id: str, dsl_dependencies: list[PluginDependency] + ) -> list[PluginDependency]: """ Returns the leaked dependencies in current workspace """ - dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies] - if not dependencies: + if not dsl_dependencies: return [] - return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies) + return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dsl_dependencies) @staticmethod def _generate_aes_key(tenant_id: str) -> bytes: diff --git a/api/services/feature_service.py b/api/services/feature_service.py index b2fb3784e8..d94ae49d91 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field from configs import dify_config from enums.cloud_plan import CloudPlan +from enums.hosted_provider import HostedTrialProvider from services.billing_service import BillingService from services.enterprise.enterprise_service import EnterpriseService @@ -170,6 +171,7 @@ class SystemFeatureModel(BaseModel): plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() enable_change_email: bool = True plugin_manager: PluginManagerModel = PluginManagerModel() + trial_models: list[str] = [] enable_trial_app: bool = False enable_explore_banner: bool = False @@ -227,9 +229,21 @@ class FeatureService: system_features.is_allow_register = dify_config.ALLOW_REGISTER system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" + system_features.trial_models = cls._fulfill_trial_models_from_env() system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER + @classmethod + def _fulfill_trial_models_from_env(cls) -> list[str]: + return [ + provider.value + for provider in HostedTrialProvider + if ( + getattr(dify_config, f"HOSTED_{provider.config_key}_PAID_ENABLED", False) + and getattr(dify_config, f"HOSTED_{provider.config_key}_TRIAL_ENABLED", False) + ) + ] + @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): features.can_replace_logo = dify_config.CAN_REPLACE_LOGO diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 06f294863d..c1c6e204fb 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -870,15 +870,16 @@ class RagPipelineDslService: return dependencies @classmethod - def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]: + def get_leaked_dependencies( + cls, tenant_id: str, dsl_dependencies: list[PluginDependency] + ) -> list[PluginDependency]: """ Returns the leaked dependencies in current workspace """ - dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies] - if not dependencies: + if not dsl_dependencies: return [] - return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies) + return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dsl_dependencies) def _generate_aes_key(self, tenant_id: str) -> bytes: """Generate AES key based on tenant_id""" diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 84f97907c0..8ea365e907 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -44,7 +44,7 @@ class RagPipelineTransformService: doc_form = dataset.doc_form if not doc_form: return self._transform_to_empty_pipeline(dataset) - retrieval_model = dataset.retrieval_model + retrieval_model = RetrievalSetting.model_validate(dataset.retrieval_model) if dataset.retrieval_model else None pipeline_yaml = self._get_transform_yaml(doc_form, datasource_type, indexing_technique) # deal dependencies self._deal_dependencies(pipeline_yaml, dataset.tenant_id) @@ -154,7 +154,12 @@ class RagPipelineTransformService: return node def _deal_knowledge_index( - self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict + self, + dataset: Dataset, + doc_form: str, + indexing_technique: str | None, + retrieval_model: RetrievalSetting | None, + node: dict, ): knowledge_configuration_dict = node.get("data", {}) knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict) @@ -163,10 +168,9 @@ class RagPipelineTransformService: knowledge_configuration.embedding_model = dataset.embedding_model knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider if retrieval_model: - retrieval_setting = RetrievalSetting.model_validate(retrieval_model) if indexing_technique == "economy": - retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH - knowledge_configuration.retrieval_model = retrieval_setting + retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH + knowledge_configuration.retrieval_model = retrieval_model else: dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index e6492c230d..b5e6508006 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3 CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:" -CACHE_REDIS_TTL = 60 * 15 # 15 minutes +CACHE_REDIS_TTL = 60 * 60 # 1 hour def _get_redis_cache_key(plugin_id: str) -> str: diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_setup.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_setup.py new file mode 100644 index 0000000000..385539b6f3 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_setup.py @@ -0,0 +1,56 @@ +import builtins +from unittest.mock import patch + +import pytest +from flask import Flask +from flask.views import MethodView + +from extensions import ext_fastopenapi + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +def test_console_setup_fastopenapi_get_not_started(app: Flask): + ext_fastopenapi.init_app(app) + + with ( + patch("controllers.console.setup.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.setup.get_setup_status", return_value=None), + ): + client = app.test_client() + response = client.get("/console/api/setup") + + assert response.status_code == 200 + assert response.get_json() == {"step": "not_started", "setup_at": None} + + +def test_console_setup_fastopenapi_post_success(app: Flask): + ext_fastopenapi.init_app(app) + + payload = { + "email": "admin@example.com", + "name": "Admin", + "password": "Passw0rd1", + "language": "en-US", + } + + with ( + patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.setup.get_setup_status", return_value=None), + patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0), + patch("controllers.console.setup.get_init_validate_status", return_value=True), + patch("controllers.console.setup.RegisterService.setup"), + ): + client = app.test_client() + response = client.post("/console/api/setup", json=payload) + + assert response.status_code == 201 + assert response.get_json() == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_version.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_version.py new file mode 100644 index 0000000000..c5b4e0dfcf --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_version.py @@ -0,0 +1,35 @@ +import builtins +from unittest.mock import patch + +import pytest +from flask import Flask +from flask.views import MethodView + +from configs import dify_config +from extensions import ext_fastopenapi + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +def test_console_version_fastopenapi_returns_current_version(app: Flask): + ext_fastopenapi.init_app(app) + + with patch("controllers.console.version.dify_config.CHECK_UPDATE_URL", None): + client = app.test_client() + response = client.get("/console/api/version", query_string={"current_version": "0.0.0"}) + + assert response.status_code == 200 + data = response.get_json() + assert data["version"] == dify_config.project.version + assert data["release_date"] == "" + assert data["release_notes"] == "" + assert data["can_auto_update"] is False + assert "features" in data diff --git a/api/tests/unit_tests/controllers/console/test_setup.py b/api/tests/unit_tests/controllers/console/test_setup.py deleted file mode 100644 index e7882dcd2b..0000000000 --- a/api/tests/unit_tests/controllers/console/test_setup.py +++ /dev/null @@ -1,39 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import patch - -from controllers.console.setup import SetupApi - - -class TestSetupApi: - def test_post_lowercases_email_before_register(self): - """Ensure setup registration normalizes email casing.""" - payload = { - "email": "Admin@Example.com", - "name": "Admin User", - "password": "ValidPass123!", - "language": "en-US", - } - setup_api = SetupApi(api=None) - - mock_console_ns = SimpleNamespace(payload=payload) - - with ( - patch("controllers.console.setup.console_ns", mock_console_ns), - patch("controllers.console.setup.get_setup_status", return_value=False), - patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0), - patch("controllers.console.setup.get_init_validate_status", return_value=True), - patch("controllers.console.setup.extract_remote_ip", return_value="127.0.0.1"), - patch("controllers.console.setup.request", object()), - patch("controllers.console.setup.RegisterService.setup") as mock_register, - ): - response, status = setup_api.post() - - assert response == {"result": "success"} - assert status == 201 - mock_register.assert_called_once_with( - email="admin@example.com", - name=payload["name"], - password=payload["password"], - ip_address="127.0.0.1", - language=payload["language"], - ) diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py new file mode 100644 index 0000000000..421a5246eb --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -0,0 +1,454 @@ +"""Test multimodal image output handling in BaseAppRunner.""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueMessageFileEvent +from core.file.enums import FileTransferMethod, FileType +from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from models.enums import CreatorUserRole + + +class TestBaseAppRunnerMultimodal: + """Test that BaseAppRunner correctly handles multimodal image content.""" + + @pytest.fixture + def mock_user_id(self): + """Mock user ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_tenant_id(self): + """Mock tenant ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_message_id(self): + """Mock message ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_queue_manager(self): + """Create a mock queue manager.""" + manager = MagicMock() + manager.invoke_from = InvokeFrom.SERVICE_API + return manager + + @pytest.fixture + def mock_tool_file(self): + """Create a mock tool file.""" + tool_file = MagicMock() + tool_file.id = str(uuid4()) + return tool_file + + @pytest.fixture + def mock_message_file(self): + """Create a mock message file.""" + message_file = MagicMock() + message_file.id = str(uuid4()) + return message_file + + def test_handle_multimodal_image_content_with_url( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from URL.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert + # Verify tool file was created from URL + mock_mgr.create_file_by_url.assert_called_once_with( + user_id=mock_user_id, + tenant_id=mock_tenant_id, + file_url=image_url, + conversation_id=None, + ) + + # Verify message file was created with correct parameters + mock_msg_file_class.assert_called_once() + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["message_id"] == mock_message_id + assert call_kwargs["type"] == FileType.IMAGE + assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE + assert call_kwargs["belongs_to"] == "assistant" + assert call_kwargs["created_by"] == mock_user_id + + # Verify database operations + mock_session.add.assert_called_once_with(mock_message_file) + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once_with(mock_message_file) + + # Verify event was published + mock_queue_manager.publish.assert_called_once() + publish_call = mock_queue_manager.publish.call_args + assert isinstance(publish_call[0][0], QueueMessageFileEvent) + assert publish_call[0][0].message_file_id == mock_message_file.id + # publish_from might be passed as positional or keyword argument + assert ( + publish_call[0][1] == PublishFrom.APPLICATION_MANAGER + or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER + ) + + def test_handle_multimodal_image_content_with_base64( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from base64 data.""" + # Arrange + import base64 + + # Create a small test image (1x1 PNG) + test_image_data = base64.b64encode( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde" + ).decode() + content = ImagePromptMessageContent( + base64_data=test_image_data, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_raw.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert + # Verify tool file was created from base64 + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + assert call_kwargs["user_id"] == mock_user_id + assert call_kwargs["tenant_id"] == mock_tenant_id + assert call_kwargs["conversation_id"] is None + assert "file_binary" in call_kwargs + assert call_kwargs["mimetype"] == "image/png" + assert call_kwargs["filename"].startswith("generated_image") + assert call_kwargs["filename"].endswith(".png") + + # Verify message file was created + mock_msg_file_class.assert_called_once() + + # Verify database operations + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + # Verify event was published + mock_queue_manager.publish.assert_called_once() + + def test_handle_multimodal_image_content_with_base64_data_uri( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from base64 data with URI prefix.""" + # Arrange + # Data URI format: data:image/png;base64, + test_image_data = ( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + ) + content = ImagePromptMessageContent( + base64_data=f"data:image/png;base64,{test_image_data}", + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_raw.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify that base64 data was extracted correctly (without prefix) + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + # The base64 data should be decoded, so we check the binary was passed + assert "file_binary" in call_kwargs + + def test_handle_multimodal_image_content_without_url_or_base64( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + ): + """Test handling image content without URL or base64 data.""" + # Arrange + content = ImagePromptMessageContent( + url="", + base64_data="", + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - should not create any files or publish events + mock_mgr_class.assert_not_called() + mock_msg_file_class.assert_not_called() + mock_session.add.assert_not_called() + mock_queue_manager.publish.assert_not_called() + + def test_handle_multimodal_image_content_with_error( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + ): + """Test handling image content when an error occurs.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock to raise exception + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.side_effect = Exception("Network error") + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + # Should not raise exception, just log it + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - should not create message file or publish event on error + mock_msg_file_class.assert_not_called() + mock_session.add.assert_not_called() + mock_queue_manager.publish.assert_not_called() + + def test_handle_multimodal_image_content_debugger_mode( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test that debugger mode sets correct created_by_role.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify created_by_role is ACCOUNT for debugger mode + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT + + def test_handle_multimodal_image_content_service_api_mode( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test that service API mode sets correct created_by_role.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify created_by_role is END_USER for service API + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index 5ef7f0d7f4..5a43a247e3 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -1,7 +1,6 @@ """Unit tests for the message cycle manager optimization.""" -from types import SimpleNamespace -from unittest.mock import ANY, Mock, patch +from unittest.mock import Mock, patch import pytest from flask import current_app @@ -28,17 +27,14 @@ class TestMessageCycleManagerOptimization: def test_get_message_event_type_with_message_file(self, message_cycle_manager): """Test get_message_event_type returns MESSAGE_FILE when message has files.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = mock_message_file + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = mock_message_file # Execute with current_app.app_context(): @@ -46,19 +42,16 @@ class TestMessageCycleManagerOptimization: # Assert assert result == StreamEvent.MESSAGE_FILE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_get_message_event_type_without_message_file(self, message_cycle_manager): """Test get_message_event_type returns MESSAGE when message has no files.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and no message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = None + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = None # Execute with current_app.app_context(): @@ -66,21 +59,18 @@ class TestMessageCycleManagerOptimization: # Assert assert result == StreamEvent.MESSAGE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = mock_message_file + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = mock_message_file # Execute: compute event type once, then pass to message_to_stream_response with current_app.app_context(): @@ -94,11 +84,11 @@ class TestMessageCycleManagerOptimization: assert result.answer == "Hello world" assert result.id == "test-message-id" assert result.event == StreamEvent.MESSAGE_FILE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager): """Test that message_to_stream_response skips database query when event_type is provided.""" - with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Execute with event_type provided result = message_cycle_manager.message_to_stream_response( answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE @@ -109,8 +99,8 @@ class TestMessageCycleManagerOptimization: assert result.answer == "Hello world" assert result.id == "test-message-id" assert result.event == StreamEvent.MESSAGE - # Should not query database when event_type is provided - mock_session_class.assert_not_called() + # Should not open a session when event_type is provided + mock_session_factory.create_session.assert_not_called() def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager): """Test message_to_stream_response with from_variable_selector parameter.""" @@ -130,24 +120,21 @@ class TestMessageCycleManagerOptimization: def test_optimization_usage_example(self, message_cycle_manager): """Test the optimization pattern that should be used by callers.""" # Step 1: Get event type once (this queries database) - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = None # No files + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = None # No files with current_app.app_context(): event_type = message_cycle_manager.get_message_event_type("test-message-id") - # Should query database once - mock_session_class.assert_called_once_with(ANY, expire_on_commit=False) + # Should open session once + mock_session_factory.create_session.assert_called_once() assert event_type == StreamEvent.MESSAGE # Step 2: Use event_type for multiple calls (no additional queries) - with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: - mock_session_class.return_value.__enter__.return_value = Mock() + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: + mock_session_factory.create_session.return_value.__enter__.return_value = Mock() chunk1_response = message_cycle_manager.message_to_stream_response( answer="Chunk 1", message_id="test-message-id", event_type=event_type @@ -157,8 +144,8 @@ class TestMessageCycleManagerOptimization: answer="Chunk 2", message_id="test-message-id", event_type=event_type ) - # Should not query database again - mock_session_class.assert_not_called() + # Should not open session again when event_type provided + mock_session_factory.create_session.assert_not_called() assert chunk1_response.event == StreamEvent.MESSAGE assert chunk2_response.event == StreamEvent.MESSAGE diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py index 93d8a20cac..5fbdabceed 100644 --- a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py +++ b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py @@ -1,5 +1,7 @@ from unittest.mock import MagicMock, patch +import pytest + from core.model_runtime.entities.message_entities import AssistantPromptMessage from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call @@ -97,3 +99,14 @@ def test__increase_tool_call(): mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4] with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator): _run_case(INPUTS_CASE_4, EXPECTED_CASE_4) + + +def test__increase_tool_call__no_id_no_name_first_delta_should_raise(): + inputs = [ + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='"value"}')), + ] + actual: list[ToolCall] = [] + with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()): + with pytest.raises(ValueError): + _increase_tool_call(inputs, actual) diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py new file mode 100644 index 0000000000..91352b2a5f --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py @@ -0,0 +1,103 @@ +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result + + +def _make_chunk( + *, + model: str = "test-model", + content: str | list[TextPromptMessageContent] | None, + tool_calls: list[AssistantPromptMessage.ToolCall] | None = None, + usage: LLMUsage | None = None, + system_fingerprint: str | None = None, +) -> LLMResultChunk: + message = AssistantPromptMessage(content=content, tool_calls=tool_calls or []) + delta = LLMResultChunkDelta(index=0, message=message, usage=usage) + return LLMResultChunk(model=model, delta=delta, system_fingerprint=system_fingerprint) + + +def test__normalize_non_stream_plugin_result__from_first_chunk_str_content_and_tool_calls(): + prompt_messages = [UserPromptMessage(content="hi")] + + tool_calls = [ + AssistantPromptMessage.ToolCall( + id="1", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments=""), + ), + AssistantPromptMessage.ToolCall( + id="", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='{"arg1": '), + ), + AssistantPromptMessage.ToolCall( + id="", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='"value"}'), + ), + ] + + usage = LLMUsage.empty_usage().model_copy(update={"prompt_tokens": 1, "total_tokens": 1}) + chunk = _make_chunk(content="hello", tool_calls=tool_calls, usage=usage, system_fingerprint="fp-1") + + result = _normalize_non_stream_plugin_result( + model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) + ) + + assert result.model == "test-model" + assert result.prompt_messages == prompt_messages + assert result.message.content == "hello" + assert result.usage.prompt_tokens == 1 + assert result.system_fingerprint == "fp-1" + assert result.message.tool_calls == [ + AssistantPromptMessage.ToolCall( + id="1", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'), + ) + ] + + +def test__normalize_non_stream_plugin_result__from_first_chunk_list_content(): + prompt_messages = [UserPromptMessage(content="hi")] + + content_list = [TextPromptMessageContent(data="a"), TextPromptMessageContent(data="b")] + chunk = _make_chunk(content=content_list, usage=LLMUsage.empty_usage()) + + result = _normalize_non_stream_plugin_result( + model="test-model", prompt_messages=prompt_messages, result=iter([chunk]) + ) + + assert result.message.content == content_list + + +def test__normalize_non_stream_plugin_result__passthrough_llm_result(): + prompt_messages = [UserPromptMessage(content="hi")] + llm_result = LLMResult( + model="test-model", + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content="ok"), + usage=LLMUsage.empty_usage(), + ) + + assert ( + _normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=llm_result) + == llm_result + ) + + +def test__normalize_non_stream_plugin_result__empty_iterator_defaults(): + prompt_messages = [UserPromptMessage(content="hi")] + + result = _normalize_non_stream_plugin_result(model="test-model", prompt_messages=prompt_messages, result=iter([])) + + assert result.model == "test-model" + assert result.prompt_messages == prompt_messages + assert result.message.content == [] + assert result.message.tool_calls == [] + assert result.usage == LLMUsage.empty_usage() + assert result.system_fingerprint is None diff --git a/api/uv.lock b/api/uv.lock index 7853d06bf6..7c3e118515 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1633,7 +1633,7 @@ requires-dist = [ { name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=2.2.2" }, { name = "psycogreen", specifier = "~=1.0.2" }, { name = "psycopg2-binary", specifier = "~=2.9.6" }, - { name = "pycryptodome", specifier = "==3.19.1" }, + { name = "pycryptodome", specifier = "==3.23.0" }, { name = "pydantic", specifier = "~=2.11.4" }, { name = "pydantic-extra-types", specifier = "~=2.10.3" }, { name = "pydantic-settings", specifier = "~=2.11.0" }, @@ -4796,20 +4796,21 @@ wheels = [ [[package]] name = "pycryptodome" -version = "3.19.1" +version = "3.23.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b1/38/42a8855ff1bf568c61ca6557e2203f318fb7afeadaf2eb8ecfdbde107151/pycryptodome-3.19.1.tar.gz", hash = "sha256:8ae0dd1bcfada451c35f9e29a3e5db385caabc190f98e4a80ad02a61098fb776", size = 4782144, upload-time = "2023-12-28T06:52:40.741Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/a6/8452177684d5e906854776276ddd34eca30d1b1e15aa1ee9cefc289a33f5/pycryptodome-3.23.0.tar.gz", hash = "sha256:447700a657182d60338bab09fdb27518f8856aecd80ae4c6bdddb67ff5da44ef", size = 4921276, upload-time = "2025-05-17T17:21:45.242Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/ef/4931bc30674f0de0ca0e827b58c8b0c17313a8eae2754976c610b866118b/pycryptodome-3.19.1-cp35-abi3-macosx_10_9_universal2.whl", hash = "sha256:67939a3adbe637281c611596e44500ff309d547e932c449337649921b17b6297", size = 2417027, upload-time = "2023-12-28T06:51:50.138Z" }, - { url = "https://files.pythonhosted.org/packages/67/e6/238c53267fd8d223029c0a0d3730cb1b6594d60f62e40c4184703dc490b1/pycryptodome-3.19.1-cp35-abi3-macosx_10_9_x86_64.whl", hash = "sha256:11ddf6c9b52116b62223b6a9f4741bc4f62bb265392a4463282f7f34bb287180", size = 1579728, upload-time = "2023-12-28T06:51:52.385Z" }, - { url = "https://files.pythonhosted.org/packages/7c/87/7181c42c8d5ba89822a4b824830506d0aeec02959bb893614767e3279846/pycryptodome-3.19.1-cp35-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3e6f89480616781d2a7f981472d0cdb09b9da9e8196f43c1234eff45c915766", size = 2051440, upload-time = "2023-12-28T06:51:55.751Z" }, - { url = "https://files.pythonhosted.org/packages/34/dd/332c4c0055527d17dac317ed9f9c864fc047b627d82f4b9a56c110afc6fc/pycryptodome-3.19.1-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27e1efcb68993b7ce5d1d047a46a601d41281bba9f1971e6be4aa27c69ab8065", size = 2125379, upload-time = "2023-12-28T06:51:58.567Z" }, - { url = "https://files.pythonhosted.org/packages/24/9e/320b885ea336c218ff54ec2b276cd70ba6904e4f5a14a771ed39a2c47d59/pycryptodome-3.19.1-cp35-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c6273ca5a03b672e504995529b8bae56da0ebb691d8ef141c4aa68f60765700", size = 2153951, upload-time = "2023-12-28T06:52:01.699Z" }, - { url = "https://files.pythonhosted.org/packages/f4/54/8ae0c43d1257b41bc9d3277c3f875174fd8ad86b9567f0b8609b99c938ee/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:b0bfe61506795877ff974f994397f0c862d037f6f1c0bfc3572195fc00833b96", size = 2044041, upload-time = "2023-12-28T06:52:03.737Z" }, - { url = "https://files.pythonhosted.org/packages/45/93/f8450a92cc38541c3ba1f4cb4e267e15ae6d6678ca617476d52c3a3764d4/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_i686.whl", hash = "sha256:f34976c5c8eb79e14c7d970fb097482835be8d410a4220f86260695ede4c3e17", size = 2182446, upload-time = "2023-12-28T06:52:05.588Z" }, - { url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914, upload-time = "2023-12-28T06:52:07.44Z" }, - { url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105, upload-time = "2023-12-28T06:52:09.585Z" }, - { url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222, upload-time = "2023-12-28T06:52:11.534Z" }, + { url = "https://files.pythonhosted.org/packages/db/6c/a1f71542c969912bb0e106f64f60a56cc1f0fabecf9396f45accbe63fa68/pycryptodome-3.23.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:187058ab80b3281b1de11c2e6842a357a1f71b42cb1e15bce373f3d238135c27", size = 2495627, upload-time = "2025-05-17T17:20:47.139Z" }, + { url = "https://files.pythonhosted.org/packages/6e/4e/a066527e079fc5002390c8acdd3aca431e6ea0a50ffd7201551175b47323/pycryptodome-3.23.0-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:cfb5cd445280c5b0a4e6187a7ce8de5a07b5f3f897f235caa11f1f435f182843", size = 1640362, upload-time = "2025-05-17T17:20:50.392Z" }, + { url = "https://files.pythonhosted.org/packages/50/52/adaf4c8c100a8c49d2bd058e5b551f73dfd8cb89eb4911e25a0c469b6b4e/pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67bd81fcbe34f43ad9422ee8fd4843c8e7198dd88dd3d40e6de42ee65fbe1490", size = 2182625, upload-time = "2025-05-17T17:20:52.866Z" }, + { url = "https://files.pythonhosted.org/packages/5f/e9/a09476d436d0ff1402ac3867d933c61805ec2326c6ea557aeeac3825604e/pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8987bd3307a39bc03df5c8e0e3d8be0c4c3518b7f044b0f4c15d1aa78f52575", size = 2268954, upload-time = "2025-05-17T17:20:55.027Z" }, + { url = "https://files.pythonhosted.org/packages/f9/c5/ffe6474e0c551d54cab931918127c46d70cab8f114e0c2b5a3c071c2f484/pycryptodome-3.23.0-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa0698f65e5b570426fc31b8162ed4603b0c2841cbb9088e2b01641e3065915b", size = 2308534, upload-time = "2025-05-17T17:20:57.279Z" }, + { url = "https://files.pythonhosted.org/packages/18/28/e199677fc15ecf43010f2463fde4c1a53015d1fe95fb03bca2890836603a/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:53ecbafc2b55353edcebd64bf5da94a2a2cdf5090a6915bcca6eca6cc452585a", size = 2181853, upload-time = "2025-05-17T17:20:59.322Z" }, + { url = "https://files.pythonhosted.org/packages/ce/ea/4fdb09f2165ce1365c9eaefef36625583371ee514db58dc9b65d3a255c4c/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:156df9667ad9f2ad26255926524e1c136d6664b741547deb0a86a9acf5ea631f", size = 2342465, upload-time = "2025-05-17T17:21:03.83Z" }, + { url = "https://files.pythonhosted.org/packages/22/82/6edc3fc42fe9284aead511394bac167693fb2b0e0395b28b8bedaa07ef04/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:dea827b4d55ee390dc89b2afe5927d4308a8b538ae91d9c6f7a5090f397af1aa", size = 2267414, upload-time = "2025-05-17T17:21:06.72Z" }, + { url = "https://files.pythonhosted.org/packages/59/fe/aae679b64363eb78326c7fdc9d06ec3de18bac68be4b612fc1fe8902693c/pycryptodome-3.23.0-cp37-abi3-win32.whl", hash = "sha256:507dbead45474b62b2bbe318eb1c4c8ee641077532067fec9c1aa82c31f84886", size = 1768484, upload-time = "2025-05-17T17:21:08.535Z" }, + { url = "https://files.pythonhosted.org/packages/54/2f/e97a1b8294db0daaa87012c24a7bb714147c7ade7656973fd6c736b484ff/pycryptodome-3.23.0-cp37-abi3-win_amd64.whl", hash = "sha256:c75b52aacc6c0c260f204cbdd834f76edc9fb0d8e0da9fbf8352ef58202564e2", size = 1799636, upload-time = "2025-05-17T17:21:10.393Z" }, + { url = "https://files.pythonhosted.org/packages/18/3d/f9441a0d798bf2b1e645adc3265e55706aead1255ccdad3856dbdcffec14/pycryptodome-3.23.0-cp37-abi3-win_arm64.whl", hash = "sha256:11eeeb6917903876f134b56ba11abe95c0b0fd5e3330def218083c7d98bbcb3c", size = 1703675, upload-time = "2025-05-17T17:21:13.146Z" }, ] [[package]] diff --git a/web/.storybook/main.ts b/web/.storybook/main.ts index ca56261431..918860c786 100644 --- a/web/.storybook/main.ts +++ b/web/.storybook/main.ts @@ -1,27 +1,15 @@ -import type { StorybookConfig } from '@storybook/nextjs' -import path from 'node:path' -import { fileURLToPath } from 'node:url' - -const storybookDir = path.dirname(fileURLToPath(import.meta.url)) +import type { StorybookConfig } from '@storybook/nextjs-vite' const config: StorybookConfig = { stories: ['../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)'], addons: [ - '@storybook/addon-onboarding', + // Not working with Storybook Vite framework + // '@storybook/addon-onboarding', '@storybook/addon-links', '@storybook/addon-docs', '@chromatic-com/storybook', ], - framework: { - name: '@storybook/nextjs', - options: { - builder: { - useSWC: true, - lazyCompilation: false, - }, - nextConfigPath: undefined, - }, - }, + framework: '@storybook/nextjs-vite', staticDirs: ['../public'], core: { disableWhatsNewNotifications: true, @@ -29,17 +17,5 @@ const config: StorybookConfig = { docs: { defaultName: 'Documentation', }, - webpackFinal: async (config) => { - // Add alias to mock problematic modules with circular dependencies - config.resolve = config.resolve || {} - config.resolve.alias = { - ...config.resolve.alias, - // Mock the plugin index files to avoid circular dependencies - [path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/context-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/context-block.tsx'), - [path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/history-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/history-block.tsx'), - [path.resolve(storybookDir, '../app/components/base/prompt-editor/plugins/query-block/index.tsx')]: path.resolve(storybookDir, '__mocks__/query-block.tsx'), - } - return config - }, } export default config diff --git a/web/app/components/app/overview/settings/index.spec.tsx b/web/app/components/app/overview/settings/index.spec.tsx index 776c55d149..c9cbe0b724 100644 --- a/web/app/components/app/overview/settings/index.spec.tsx +++ b/web/app/components/app/overview/settings/index.spec.tsx @@ -1,3 +1,6 @@ +/** + * @vitest-environment jsdom + */ import type { ReactNode } from 'react' import type { ModalContextState } from '@/context/modal-context' import type { ProviderContextState } from '@/context/provider-context' diff --git a/web/app/components/base/action-button/index.stories.tsx b/web/app/components/base/action-button/index.stories.tsx index 07e0592374..d6f0767faa 100644 --- a/web/app/components/base/action-button/index.stories.tsx +++ b/web/app/components/base/action-button/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { RiAddLine, RiDeleteBinLine, RiEditLine, RiMore2Fill, RiSaveLine, RiShareLine } from '@remixicon/react' import ActionButton, { ActionButtonState } from '.' diff --git a/web/app/components/base/agent-log-modal/index.stories.tsx b/web/app/components/base/agent-log-modal/index.stories.tsx index 781782af8d..87318848b4 100644 --- a/web/app/components/base/agent-log-modal/index.stories.tsx +++ b/web/app/components/base/agent-log-modal/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { IChatItem } from '@/app/components/base/chat/chat/type' import type { AgentLogDetailResponse } from '@/models/log' import { useEffect, useRef } from 'react' diff --git a/web/app/components/base/answer-icon/index.stories.tsx b/web/app/components/base/answer-icon/index.stories.tsx index 0928d9cda6..d5de350a40 100644 --- a/web/app/components/base/answer-icon/index.stories.tsx +++ b/web/app/components/base/answer-icon/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { ReactNode } from 'react' import AnswerIcon from '.' diff --git a/web/app/components/base/app-icon-picker/index.stories.tsx b/web/app/components/base/app-icon-picker/index.stories.tsx index 43abfccc39..08e9d69f32 100644 --- a/web/app/components/base/app-icon-picker/index.stories.tsx +++ b/web/app/components/base/app-icon-picker/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { AppIconSelection } from '.' import { useState } from 'react' import AppIconPicker from '.' diff --git a/web/app/components/base/app-icon/index.stories.tsx b/web/app/components/base/app-icon/index.stories.tsx index 9fdffb54b0..a645471254 100644 --- a/web/app/components/base/app-icon/index.stories.tsx +++ b/web/app/components/base/app-icon/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { ComponentProps } from 'react' import AppIcon from '.' diff --git a/web/app/components/base/audio-btn/index.stories.tsx b/web/app/components/base/audio-btn/index.stories.tsx index e560b9af99..c760e1366d 100644 --- a/web/app/components/base/audio-btn/index.stories.tsx +++ b/web/app/components/base/audio-btn/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { ComponentProps } from 'react' import { useEffect } from 'react' import AudioBtn from '.' diff --git a/web/app/components/base/audio-gallery/index.stories.tsx b/web/app/components/base/audio-gallery/index.stories.tsx index 539ab9e332..cf22058c9a 100644 --- a/web/app/components/base/audio-gallery/index.stories.tsx +++ b/web/app/components/base/audio-gallery/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import AudioGallery from '.' const AUDIO_SOURCES = [ diff --git a/web/app/components/base/auto-height-textarea/index.stories.tsx b/web/app/components/base/auto-height-textarea/index.stories.tsx index d0f36e4736..f5239a49ca 100644 --- a/web/app/components/base/auto-height-textarea/index.stories.tsx +++ b/web/app/components/base/auto-height-textarea/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import AutoHeightTextarea from '.' diff --git a/web/app/components/base/avatar/index.stories.tsx b/web/app/components/base/avatar/index.stories.tsx index fdc9bf8281..5e392640ca 100644 --- a/web/app/components/base/avatar/index.stories.tsx +++ b/web/app/components/base/avatar/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import Avatar from '.' const meta = { diff --git a/web/app/components/base/badge/index.stories.tsx b/web/app/components/base/badge/index.stories.tsx index e1fe8cb271..b2ab794087 100644 --- a/web/app/components/base/badge/index.stories.tsx +++ b/web/app/components/base/badge/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import Badge from '../badge' const meta = { diff --git a/web/app/components/base/block-input/index.stories.tsx b/web/app/components/base/block-input/index.stories.tsx index d05cc221b6..484b917c75 100644 --- a/web/app/components/base/block-input/index.stories.tsx +++ b/web/app/components/base/block-input/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import BlockInput from '.' diff --git a/web/app/components/base/button/add-button.stories.tsx b/web/app/components/base/button/add-button.stories.tsx index edd52b2b78..5181309f2c 100644 --- a/web/app/components/base/button/add-button.stories.tsx +++ b/web/app/components/base/button/add-button.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import AddButton from './add-button' const meta = { diff --git a/web/app/components/base/button/index.stories.tsx b/web/app/components/base/button/index.stories.tsx index 02d20b4af4..25bd5957e1 100644 --- a/web/app/components/base/button/index.stories.tsx +++ b/web/app/components/base/button/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { RocketLaunchIcon } from '@heroicons/react/20/solid' import { Button } from '.' diff --git a/web/app/components/base/button/sync-button.stories.tsx b/web/app/components/base/button/sync-button.stories.tsx index dcfbf6daf3..5a5c078ec1 100644 --- a/web/app/components/base/button/sync-button.stories.tsx +++ b/web/app/components/base/button/sync-button.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import SyncButton from './sync-button' const meta = { diff --git a/web/app/components/base/chat/chat/answer/index.stories.tsx b/web/app/components/base/chat/chat/answer/index.stories.tsx index a8e42b7ad3..f39746efb2 100644 --- a/web/app/components/base/chat/chat/answer/index.stories.tsx +++ b/web/app/components/base/chat/chat/answer/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { ChatItem } from '../../types' import { WorkflowRunningStatus } from '@/app/components/workflow/types' import Answer from '.' diff --git a/web/app/components/base/chat/chat/hooks.multimodal.spec.ts b/web/app/components/base/chat/chat/hooks.multimodal.spec.ts new file mode 100644 index 0000000000..2975d62887 --- /dev/null +++ b/web/app/components/base/chat/chat/hooks.multimodal.spec.ts @@ -0,0 +1,178 @@ +/** + * Tests for multimodal image file handling in chat hooks. + * Tests the file object conversion logic without full hook integration. + */ + +describe('Multimodal File Handling', () => { + describe('File type to MIME type mapping', () => { + it('should map image to image/png', () => { + const fileType: string = 'image' + const expectedMime = 'image/png' + const mimeType = fileType === 'image' ? 'image/png' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + + it('should map video to video/mp4', () => { + const fileType: string = 'video' + const expectedMime = 'video/mp4' + const mimeType = fileType === 'video' ? 'video/mp4' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + + it('should map audio to audio/mpeg', () => { + const fileType: string = 'audio' + const expectedMime = 'audio/mpeg' + const mimeType = fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + + it('should map unknown to application/octet-stream', () => { + const fileType: string = 'unknown' + const expectedMime = 'application/octet-stream' + const mimeType = ['image', 'video', 'audio'].includes(fileType) ? 'image/png' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + }) + + describe('TransferMethod selection', () => { + it('should select remote_url for images', () => { + const fileType: string = 'image' + const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file' + expect(transferMethod).toBe('remote_url') + }) + + it('should select local_file for non-images', () => { + const fileType: string = 'video' + const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file' + expect(transferMethod).toBe('local_file') + }) + }) + + describe('File extension mapping', () => { + it('should use .png extension for images', () => { + const fileType: string = 'image' + const expectedExtension = '.png' + const extension = fileType === 'image' ? 'png' : 'bin' + expect(extension).toBe(expectedExtension.replace('.', '')) + }) + + it('should use .mp4 extension for videos', () => { + const fileType: string = 'video' + const expectedExtension = '.mp4' + const extension = fileType === 'video' ? 'mp4' : 'bin' + expect(extension).toBe(expectedExtension.replace('.', '')) + }) + + it('should use .mp3 extension for audio', () => { + const fileType: string = 'audio' + const expectedExtension = '.mp3' + const extension = fileType === 'audio' ? 'mp3' : 'bin' + expect(extension).toBe(expectedExtension.replace('.', '')) + }) + }) + + describe('File name generation', () => { + it('should generate correct file name for images', () => { + const fileType: string = 'image' + const expectedName = 'generated_image.png' + const fileName = `generated_${fileType}.${fileType === 'image' ? 'png' : 'bin'}` + expect(fileName).toBe(expectedName) + }) + + it('should generate correct file name for videos', () => { + const fileType: string = 'video' + const expectedName = 'generated_video.mp4' + const fileName = `generated_${fileType}.${fileType === 'video' ? 'mp4' : 'bin'}` + expect(fileName).toBe(expectedName) + }) + + it('should generate correct file name for audio', () => { + const fileType: string = 'audio' + const expectedName = 'generated_audio.mp3' + const fileName = `generated_${fileType}.${fileType === 'audio' ? 'mp3' : 'bin'}` + expect(fileName).toBe(expectedName) + }) + }) + + describe('SupportFileType mapping', () => { + it('should map image type to image supportFileType', () => { + const fileType: string = 'image' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('image') + }) + + it('should map video type to video supportFileType', () => { + const fileType: string = 'video' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('video') + }) + + it('should map audio type to audio supportFileType', () => { + const fileType: string = 'audio' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('audio') + }) + + it('should map unknown type to document supportFileType', () => { + const fileType: string = 'unknown' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('document') + }) + }) + + describe('File conversion logic', () => { + it('should detect existing transferMethod', () => { + const fileWithTransferMethod = { + id: 'file-123', + transferMethod: 'remote_url' as const, + type: 'image/png', + name: 'test.png', + size: 1024, + supportFileType: 'image', + progress: 100, + } + const hasTransferMethod = 'transferMethod' in fileWithTransferMethod + expect(hasTransferMethod).toBe(true) + }) + + it('should detect missing transferMethod', () => { + const fileWithoutTransferMethod = { + id: 'file-456', + type: 'image', + url: 'http://example.com/image.png', + belongs_to: 'assistant', + } + const hasTransferMethod = 'transferMethod' in fileWithoutTransferMethod + expect(hasTransferMethod).toBe(false) + }) + + it('should create file with size 0 for generated files', () => { + const expectedSize = 0 + expect(expectedSize).toBe(0) + }) + }) + + describe('Agent vs Non-Agent mode logic', () => { + it('should check for agent_thoughts to determine mode', () => { + const agentResponse: { agent_thoughts?: Array> } = { + agent_thoughts: [{}], + } + const isAgentMode = agentResponse.agent_thoughts && agentResponse.agent_thoughts.length > 0 + expect(isAgentMode).toBe(true) + }) + + it('should detect non-agent mode when agent_thoughts is empty', () => { + const nonAgentResponse: { agent_thoughts?: Array> } = { + agent_thoughts: [], + } + const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0 + expect(isAgentMode).toBe(false) + }) + + it('should detect non-agent mode when agent_thoughts is undefined', () => { + const nonAgentResponse: { agent_thoughts?: Array> } = {} + const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0 + expect(isAgentMode).toBeFalsy() + }) + }) +}) diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index e9365d98cb..801382e550 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -745,9 +745,40 @@ export const useChat = ( } }, onFile(file) { + // Convert simple file type to MIME type for non-agent mode + // Backend sends: { id, type: "image", belongs_to, url } + // Frontend expects: { id, type: "image/png", transferMethod, url, uploadedId, supportFileType, name, size } + + // Determine file type for MIME conversion + const fileType = (file as { type?: string }).type || 'image' + + // If file already has transferMethod, use it as base and ensure all required fields exist + // Otherwise, create a new complete file object + const baseFile = ('transferMethod' in file) ? (file as Partial) : null + + const convertedFile: FileEntity = { + id: baseFile?.id || (file as { id: string }).id, + type: baseFile?.type || (fileType === 'image' ? 'image/png' : fileType === 'video' ? 'video/mp4' : fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'), + transferMethod: (baseFile?.transferMethod as FileEntity['transferMethod']) || (fileType === 'image' ? 'remote_url' : 'local_file'), + uploadedId: baseFile?.uploadedId || (file as { id: string }).id, + supportFileType: baseFile?.supportFileType || (fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'), + progress: baseFile?.progress ?? 100, + name: baseFile?.name || `generated_${fileType}.${fileType === 'image' ? 'png' : fileType === 'video' ? 'mp4' : fileType === 'audio' ? 'mp3' : 'bin'}`, + url: baseFile?.url || (file as { url?: string }).url, + size: baseFile?.size ?? 0, // Generated files don't have a known size + } + + // For agent mode, add files to the last thought const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1] - if (lastThought) - responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(lastThought as any).message_files, file] + if (lastThought) { + const thought = lastThought as { message_files?: FileEntity[] } + responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(thought.message_files ?? []), convertedFile] + } + // For non-agent mode, add files directly to responseItem.message_files + else { + const currentFiles = (responseItem.message_files as FileEntity[] | undefined) ?? [] + responseItem.message_files = [...currentFiles, convertedFile] + } updateCurrentQAOnTree({ placeholderQuestionId, diff --git a/web/app/components/base/chat/chat/question.stories.tsx b/web/app/components/base/chat/chat/question.stories.tsx index 4542dc3ac6..308a096864 100644 --- a/web/app/components/base/chat/chat/question.stories.tsx +++ b/web/app/components/base/chat/chat/question.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { ChatItem } from '../types' import { User } from '@/app/components/base/icons/src/public/avatar' diff --git a/web/app/components/base/checkbox/index.stories.tsx b/web/app/components/base/checkbox/index.stories.tsx index 580d731a7e..0dac4cae15 100644 --- a/web/app/components/base/checkbox/index.stories.tsx +++ b/web/app/components/base/checkbox/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import Checkbox from '.' diff --git a/web/app/components/base/chip/index.stories.tsx b/web/app/components/base/chip/index.stories.tsx index fc43ae8724..5812f97d98 100644 --- a/web/app/components/base/chip/index.stories.tsx +++ b/web/app/components/base/chip/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { Item } from '.' import { useState } from 'react' import Chip from '.' diff --git a/web/app/components/base/confirm/index.stories.tsx b/web/app/components/base/confirm/index.stories.tsx index 12cb46d9e4..6d22bbe87b 100644 --- a/web/app/components/base/confirm/index.stories.tsx +++ b/web/app/components/base/confirm/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import Confirm from '.' import Button from '../button' diff --git a/web/app/components/base/content-dialog/index.stories.tsx b/web/app/components/base/content-dialog/index.stories.tsx index aaebcad1b7..8ddd5c667d 100644 --- a/web/app/components/base/content-dialog/index.stories.tsx +++ b/web/app/components/base/content-dialog/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useEffect, useState } from 'react' import ContentDialog from '.' diff --git a/web/app/components/base/copy-feedback/index.stories.tsx b/web/app/components/base/copy-feedback/index.stories.tsx index 3bab620aec..aa535993f8 100644 --- a/web/app/components/base/copy-feedback/index.stories.tsx +++ b/web/app/components/base/copy-feedback/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import CopyFeedback, { CopyFeedbackNew } from '.' diff --git a/web/app/components/base/copy-icon/index.stories.tsx b/web/app/components/base/copy-icon/index.stories.tsx index 5962773792..dd13343819 100644 --- a/web/app/components/base/copy-icon/index.stories.tsx +++ b/web/app/components/base/copy-icon/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import CopyIcon from '.' const meta = { diff --git a/web/app/components/base/corner-label/index.stories.tsx b/web/app/components/base/corner-label/index.stories.tsx index 1592f94259..dbfab31da0 100644 --- a/web/app/components/base/corner-label/index.stories.tsx +++ b/web/app/components/base/corner-label/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import CornerLabel from '.' const meta = { diff --git a/web/app/components/base/date-and-time-picker/index.stories.tsx b/web/app/components/base/date-and-time-picker/index.stories.tsx index ad057f7969..1ed35afe88 100644 --- a/web/app/components/base/date-and-time-picker/index.stories.tsx +++ b/web/app/components/base/date-and-time-picker/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { DatePickerProps } from './types' import { useState } from 'react' import { fn } from 'storybook/test' diff --git a/web/app/components/base/dialog/index.stories.tsx b/web/app/components/base/dialog/index.stories.tsx index f573b856d3..af2e669535 100644 --- a/web/app/components/base/dialog/index.stories.tsx +++ b/web/app/components/base/dialog/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useEffect, useState } from 'react' import Dialog from '.' diff --git a/web/app/components/base/divider/index.stories.tsx b/web/app/components/base/divider/index.stories.tsx index c634173202..2ae00eca47 100644 --- a/web/app/components/base/divider/index.stories.tsx +++ b/web/app/components/base/divider/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import Divider from '.' const meta = { diff --git a/web/app/components/base/drawer-plus/index.stories.tsx b/web/app/components/base/drawer-plus/index.stories.tsx index c79dd8af8a..4bdfef2ab3 100644 --- a/web/app/components/base/drawer-plus/index.stories.tsx +++ b/web/app/components/base/drawer-plus/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import { fn } from 'storybook/test' import DrawerPlus from '.' diff --git a/web/app/components/base/drawer/index.stories.tsx b/web/app/components/base/drawer/index.stories.tsx index cfcfbf6a2e..ca7b3bc243 100644 --- a/web/app/components/base/drawer/index.stories.tsx +++ b/web/app/components/base/drawer/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import { fn } from 'storybook/test' import Drawer from '.' diff --git a/web/app/components/base/dropdown/index.stories.tsx b/web/app/components/base/dropdown/index.stories.tsx index 4b08d54c47..7cb7f820f6 100644 --- a/web/app/components/base/dropdown/index.stories.tsx +++ b/web/app/components/base/dropdown/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { Item } from '.' import { useState } from 'react' import { fn } from 'storybook/test' diff --git a/web/app/components/base/effect/index.stories.tsx b/web/app/components/base/effect/index.stories.tsx index a7f316fe7e..36a0e668cf 100644 --- a/web/app/components/base/effect/index.stories.tsx +++ b/web/app/components/base/effect/index.stories.tsx @@ -1,5 +1,5 @@ /* eslint-disable tailwindcss/classnames-order */ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import Effect from '.' const meta = { diff --git a/web/app/components/base/emoji-picker/Inner.stories.tsx b/web/app/components/base/emoji-picker/Inner.stories.tsx index 642b4092e8..be0e993cce 100644 --- a/web/app/components/base/emoji-picker/Inner.stories.tsx +++ b/web/app/components/base/emoji-picker/Inner.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import EmojiPickerInner from './Inner' diff --git a/web/app/components/base/emoji-picker/index.stories.tsx b/web/app/components/base/emoji-picker/index.stories.tsx index beadcc0898..0649f32e68 100644 --- a/web/app/components/base/emoji-picker/index.stories.tsx +++ b/web/app/components/base/emoji-picker/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import EmojiPicker from '.' diff --git a/web/app/components/base/features/index.stories.tsx b/web/app/components/base/features/index.stories.tsx index c94d4faa1d..99d8df097b 100644 --- a/web/app/components/base/features/index.stories.tsx +++ b/web/app/components/base/features/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { Features } from './types' import { useState } from 'react' import { FeaturesProvider } from '.' diff --git a/web/app/components/base/file-icon/index.stories.tsx b/web/app/components/base/file-icon/index.stories.tsx index 21f9c3111c..4f0ec61c94 100644 --- a/web/app/components/base/file-icon/index.stories.tsx +++ b/web/app/components/base/file-icon/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import FileIcon from '.' const meta = { diff --git a/web/app/components/base/file-uploader/file-image-render.stories.tsx b/web/app/components/base/file-uploader/file-image-render.stories.tsx index 132c0b61a3..ca051e4b27 100644 --- a/web/app/components/base/file-uploader/file-image-render.stories.tsx +++ b/web/app/components/base/file-uploader/file-image-render.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import FileImageRender from './file-image-render' const SAMPLE_IMAGE = 'data:image/svg+xml;utf8,Preview' diff --git a/web/app/components/base/file-uploader/file-list.stories.tsx b/web/app/components/base/file-uploader/file-list.stories.tsx index 37c828c7f7..560202779a 100644 --- a/web/app/components/base/file-uploader/file-list.stories.tsx +++ b/web/app/components/base/file-uploader/file-list.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { FileEntity } from './types' import { useState } from 'react' import { SupportUploadFileTypes } from '@/app/components/workflow/types' diff --git a/web/app/components/base/file-uploader/file-type-icon.stories.tsx b/web/app/components/base/file-uploader/file-type-icon.stories.tsx index c317afab68..6a6df069b1 100644 --- a/web/app/components/base/file-uploader/file-type-icon.stories.tsx +++ b/web/app/components/base/file-uploader/file-type-icon.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import FileTypeIcon from './file-type-icon' import { FileAppearanceTypeEnum } from './types' diff --git a/web/app/components/base/file-uploader/file-uploader-in-attachment/index.stories.tsx b/web/app/components/base/file-uploader/file-uploader-in-attachment/index.stories.tsx index aa53ff17d9..f10fce1173 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-attachment/index.stories.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-attachment/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { FileEntity } from '../types' import type { FileUpload } from '@/app/components/base/features/types' import { useState } from 'react' diff --git a/web/app/components/base/file-uploader/file-uploader-in-chat-input/index.stories.tsx b/web/app/components/base/file-uploader/file-uploader-in-chat-input/index.stories.tsx index e094a48803..632fc40136 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-chat-input/index.stories.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-chat-input/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { FileEntity } from '../types' import type { FileUpload } from '@/app/components/base/features/types' import { useState } from 'react' diff --git a/web/app/components/base/float-right-container/index.stories.tsx b/web/app/components/base/float-right-container/index.stories.tsx index dcc55d4cf6..5887afd1e3 100644 --- a/web/app/components/base/float-right-container/index.stories.tsx +++ b/web/app/components/base/float-right-container/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import { fn } from 'storybook/test' import FloatRightContainer from '.' diff --git a/web/app/components/base/form/index.stories.tsx b/web/app/components/base/form/index.stories.tsx index 41e2e9deb8..3a6e052f0e 100644 --- a/web/app/components/base/form/index.stories.tsx +++ b/web/app/components/base/form/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { FormStoryRender } from '../../../../.storybook/utils/form-story-wrapper' import type { FormSchema } from './types' import { useStore } from '@tanstack/react-form' diff --git a/web/app/components/base/fullscreen-modal/index.stories.tsx b/web/app/components/base/fullscreen-modal/index.stories.tsx index 72fd28df66..3285b1c4ea 100644 --- a/web/app/components/base/fullscreen-modal/index.stories.tsx +++ b/web/app/components/base/fullscreen-modal/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import FullScreenModal from '.' diff --git a/web/app/components/base/grid-mask/index.stories.tsx b/web/app/components/base/grid-mask/index.stories.tsx index 1b67a1510d..24028f4347 100644 --- a/web/app/components/base/grid-mask/index.stories.tsx +++ b/web/app/components/base/grid-mask/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import GridMask from '.' const meta = { diff --git a/web/app/components/base/icons/icon-gallery.stories.tsx b/web/app/components/base/icons/icon-gallery.stories.tsx index 55322a7ea3..8d49f70ce2 100644 --- a/web/app/components/base/icons/icon-gallery.stories.tsx +++ b/web/app/components/base/icons/icon-gallery.stories.tsx @@ -1,9 +1,9 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +/// +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import * as React from 'react' -declare const require: any - type IconComponent = React.ComponentType> +type IconModule = { default: IconComponent } type IconEntry = { name: string @@ -12,18 +12,16 @@ type IconEntry = { Component: IconComponent } -const iconContext = require.context('./src', true, /\.tsx$/) +const iconModules: Record = import.meta.glob('./src/**/*.tsx', { eager: true }) -const iconEntries: IconEntry[] = iconContext - .keys() - .filter((key: string) => !key.endsWith('.stories.tsx') && !key.endsWith('.spec.tsx')) - .map((key: string) => { - const mod = iconContext(key) - const Component = mod.default as IconComponent | undefined +const iconEntries: IconEntry[] = Object.entries(iconModules) + .filter(([key]) => !key.endsWith('.stories.tsx') && !key.endsWith('.spec.tsx')) + .map(([key, mod]) => { + const Component = mod.default if (!Component) return null - const relativePath = key.replace(/^\.\//, '') + const relativePath = key.replace(/^\.\/src\//, '') const path = `app/components/base/icons/src/${relativePath}` const parts = relativePath.split('/') const fileName = parts.pop() || '' diff --git a/web/app/components/base/image-gallery/index.stories.tsx b/web/app/components/base/image-gallery/index.stories.tsx index c1b463170c..d3be60fd56 100644 --- a/web/app/components/base/image-gallery/index.stories.tsx +++ b/web/app/components/base/image-gallery/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import ImageGallery from '.' const IMAGE_SOURCES = [ diff --git a/web/app/components/base/image-uploader/image-list.stories.tsx b/web/app/components/base/image-uploader/image-list.stories.tsx index eabdb27aab..efafa6b4f0 100644 --- a/web/app/components/base/image-uploader/image-list.stories.tsx +++ b/web/app/components/base/image-uploader/image-list.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { ImageFile } from '@/types/app' import { useMemo, useState } from 'react' import { TransferMethod } from '@/types/app' diff --git a/web/app/components/base/inline-delete-confirm/index.stories.tsx b/web/app/components/base/inline-delete-confirm/index.stories.tsx index 56fa5a9431..c352c512a5 100644 --- a/web/app/components/base/inline-delete-confirm/index.stories.tsx +++ b/web/app/components/base/inline-delete-confirm/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import { fn } from 'storybook/test' import InlineDeleteConfirm from '.' diff --git a/web/app/components/base/input-number/index.stories.tsx b/web/app/components/base/input-number/index.stories.tsx index 997003d53f..4b7bebf216 100644 --- a/web/app/components/base/input-number/index.stories.tsx +++ b/web/app/components/base/input-number/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import { InputNumber } from '.' diff --git a/web/app/components/base/input/index.stories.tsx b/web/app/components/base/input/index.stories.tsx index 59861435ef..860f65dfc7 100644 --- a/web/app/components/base/input/index.stories.tsx +++ b/web/app/components/base/input/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import Input from '.' diff --git a/web/app/components/base/linked-apps-panel/index.stories.tsx b/web/app/components/base/linked-apps-panel/index.stories.tsx index 07787173e0..fb9d7c4fba 100644 --- a/web/app/components/base/linked-apps-panel/index.stories.tsx +++ b/web/app/components/base/linked-apps-panel/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { RelatedApp } from '@/models/datasets' import { AppModeEnum } from '@/types/app' import LinkedAppsPanel from '.' diff --git a/web/app/components/base/list-empty/index.stories.tsx b/web/app/components/base/list-empty/index.stories.tsx index 36c0e3c7a7..5079337687 100644 --- a/web/app/components/base/list-empty/index.stories.tsx +++ b/web/app/components/base/list-empty/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import ListEmpty from '.' const meta = { diff --git a/web/app/components/base/loading/index.stories.tsx b/web/app/components/base/loading/index.stories.tsx index f22f87516c..22fe746709 100644 --- a/web/app/components/base/loading/index.stories.tsx +++ b/web/app/components/base/loading/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import Loading from '.' const meta = { diff --git a/web/app/components/base/logo/index.stories.tsx b/web/app/components/base/logo/index.stories.tsx index 7bd63151e4..8347bc48da 100644 --- a/web/app/components/base/logo/index.stories.tsx +++ b/web/app/components/base/logo/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { ReactNode } from 'react' import { ThemeProvider } from 'next-themes' import DifyLogo from './dify-logo' diff --git a/web/app/components/base/markdown-blocks/code-block.stories.tsx b/web/app/components/base/markdown-blocks/code-block.stories.tsx index 98473bdf57..b9e92ada22 100644 --- a/web/app/components/base/markdown-blocks/code-block.stories.tsx +++ b/web/app/components/base/markdown-blocks/code-block.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import CodeBlock from './code-block' const SAMPLE_CODE = `const greet = (name: string) => { diff --git a/web/app/components/base/markdown-blocks/think-block.stories.tsx b/web/app/components/base/markdown-blocks/think-block.stories.tsx index 6d5c8dc418..23713fb263 100644 --- a/web/app/components/base/markdown-blocks/think-block.stories.tsx +++ b/web/app/components/base/markdown-blocks/think-block.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import { ChatContextProvider } from '@/app/components/base/chat/chat/context' import ThinkBlock from './think-block' diff --git a/web/app/components/base/markdown/index.stories.tsx b/web/app/components/base/markdown/index.stories.tsx index 8c940e01cf..289dfda147 100644 --- a/web/app/components/base/markdown/index.stories.tsx +++ b/web/app/components/base/markdown/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import { Markdown } from '.' diff --git a/web/app/components/base/mermaid/index.stories.tsx b/web/app/components/base/mermaid/index.stories.tsx index 73030d7905..70c259db08 100644 --- a/web/app/components/base/mermaid/index.stories.tsx +++ b/web/app/components/base/mermaid/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import Flowchart from '.' diff --git a/web/app/components/base/message-log-modal/index.stories.tsx b/web/app/components/base/message-log-modal/index.stories.tsx index 6c29584f7d..e370bd3338 100644 --- a/web/app/components/base/message-log-modal/index.stories.tsx +++ b/web/app/components/base/message-log-modal/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { IChatItem } from '@/app/components/base/chat/chat/type' import type { WorkflowRunDetailResponse } from '@/models/log' import type { NodeTracing, NodeTracingListResponse } from '@/types/workflow' diff --git a/web/app/components/base/modal-like-wrap/index.stories.tsx b/web/app/components/base/modal-like-wrap/index.stories.tsx index c7d66b8e6a..9e5ecd6d15 100644 --- a/web/app/components/base/modal-like-wrap/index.stories.tsx +++ b/web/app/components/base/modal-like-wrap/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import ModalLikeWrap from '.' const meta = { diff --git a/web/app/components/base/modal/index.stories.tsx b/web/app/components/base/modal/index.stories.tsx index c0ea31eb42..91bb851f20 100644 --- a/web/app/components/base/modal/index.stories.tsx +++ b/web/app/components/base/modal/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useEffect, useState } from 'react' import Modal from '.' diff --git a/web/app/components/base/modal/modal.stories.tsx b/web/app/components/base/modal/modal.stories.tsx index adb80aebe6..2ddf706866 100644 --- a/web/app/components/base/modal/modal.stories.tsx +++ b/web/app/components/base/modal/modal.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useEffect, useState } from 'react' import Modal from './modal' diff --git a/web/app/components/base/new-audio-button/index.stories.tsx b/web/app/components/base/new-audio-button/index.stories.tsx index 0bc8accec1..44a7e2616a 100644 --- a/web/app/components/base/new-audio-button/index.stories.tsx +++ b/web/app/components/base/new-audio-button/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { ComponentProps } from 'react' import { useEffect } from 'react' import AudioBtn from '.' diff --git a/web/app/components/base/notion-connector/index.stories.tsx b/web/app/components/base/notion-connector/index.stories.tsx index eb8b17df3f..d43e4a2ae6 100644 --- a/web/app/components/base/notion-connector/index.stories.tsx +++ b/web/app/components/base/notion-connector/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import NotionConnector from '.' const meta = { diff --git a/web/app/components/base/notion-icon/index.stories.tsx b/web/app/components/base/notion-icon/index.stories.tsx index 5389a6f935..68f400f363 100644 --- a/web/app/components/base/notion-icon/index.stories.tsx +++ b/web/app/components/base/notion-icon/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import NotionIcon from '.' const meta = { diff --git a/web/app/components/base/notion-page-selector/index.stories.tsx b/web/app/components/base/notion-page-selector/index.stories.tsx index 9b2c44687a..d338793363 100644 --- a/web/app/components/base/notion-page-selector/index.stories.tsx +++ b/web/app/components/base/notion-page-selector/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { DataSourceCredential } from '@/app/components/header/account-setting/data-source-page-new/types' import type { NotionPage } from '@/models/common' import { useEffect, useMemo, useState } from 'react' diff --git a/web/app/components/base/pagination/index.stories.tsx b/web/app/components/base/pagination/index.stories.tsx index 4ad5488b96..e53f285bb2 100644 --- a/web/app/components/base/pagination/index.stories.tsx +++ b/web/app/components/base/pagination/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useMemo, useState } from 'react' import Pagination from '.' diff --git a/web/app/components/base/param-item/index.stories.tsx b/web/app/components/base/param-item/index.stories.tsx index 1b5b233f6d..0cf6c40146 100644 --- a/web/app/components/base/param-item/index.stories.tsx +++ b/web/app/components/base/param-item/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import ParamItem from '.' diff --git a/web/app/components/base/popover/index.stories.tsx b/web/app/components/base/popover/index.stories.tsx index ab57bc15cc..0076c1852b 100644 --- a/web/app/components/base/popover/index.stories.tsx +++ b/web/app/components/base/popover/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import CustomPopover from '.' diff --git a/web/app/components/base/portal-to-follow-elem/index.stories.tsx b/web/app/components/base/portal-to-follow-elem/index.stories.tsx index bbe5e9d206..c9c43f34c6 100644 --- a/web/app/components/base/portal-to-follow-elem/index.stories.tsx +++ b/web/app/components/base/portal-to-follow-elem/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import { PortalToFollowElem, diff --git a/web/app/components/base/portal-to-follow-elem/index.tsx b/web/app/components/base/portal-to-follow-elem/index.tsx index a656ab5308..c57fba9dd0 100644 --- a/web/app/components/base/portal-to-follow-elem/index.tsx +++ b/web/app/components/base/portal-to-follow-elem/index.tsx @@ -61,9 +61,12 @@ export function usePortalToFollowElem({ }), shift({ padding: 5 }), size({ - apply({ rects, elements }) { - if (triggerPopupSameWidth) - elements.floating.style.width = `${rects.reference.width}px` + apply({ rects, elements, availableHeight }) { + Object.assign(elements.floating.style, { + maxHeight: `${Math.max(0, availableHeight)}px`, + overflowY: 'auto', + ...(triggerPopupSameWidth && { width: `${rects.reference.width}px` }), + }) }, }), ], diff --git a/web/app/components/base/premium-badge/index.stories.tsx b/web/app/components/base/premium-badge/index.stories.tsx index c1f6ede869..9d892cbf61 100644 --- a/web/app/components/base/premium-badge/index.stories.tsx +++ b/web/app/components/base/premium-badge/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import PremiumBadge from '.' const colors: Array['color']>> = ['blue', 'indigo', 'gray', 'orange'] diff --git a/web/app/components/base/progress-bar/progress-circle.stories.tsx b/web/app/components/base/progress-bar/progress-circle.stories.tsx index 10f8ce6c28..1dd52d5683 100644 --- a/web/app/components/base/progress-bar/progress-circle.stories.tsx +++ b/web/app/components/base/progress-bar/progress-circle.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import ProgressCircle from './progress-circle' diff --git a/web/app/components/base/prompt-editor/index.stories.tsx b/web/app/components/base/prompt-editor/index.stories.tsx index f7b0812ff7..92f23345be 100644 --- a/web/app/components/base/prompt-editor/index.stories.tsx +++ b/web/app/components/base/prompt-editor/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' // Mock component to avoid complex initialization issues diff --git a/web/app/components/base/prompt-log-modal/index.stories.tsx b/web/app/components/base/prompt-log-modal/index.stories.tsx index 39fab32030..42f90e6a57 100644 --- a/web/app/components/base/prompt-log-modal/index.stories.tsx +++ b/web/app/components/base/prompt-log-modal/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { IChatItem } from '@/app/components/base/chat/chat/type' import { useEffect } from 'react' import { useStore } from '@/app/components/app/store' diff --git a/web/app/components/base/qrcode/index.stories.tsx b/web/app/components/base/qrcode/index.stories.tsx index 312dc6a5a8..700a71fceb 100644 --- a/web/app/components/base/qrcode/index.stories.tsx +++ b/web/app/components/base/qrcode/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import ShareQRCode from '.' const QRDemo = ({ diff --git a/web/app/components/base/radio-card/index.stories.tsx b/web/app/components/base/radio-card/index.stories.tsx index 3adccfaf0d..40ef7069f5 100644 --- a/web/app/components/base/radio-card/index.stories.tsx +++ b/web/app/components/base/radio-card/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { RiCloudLine, RiCpuLine, RiDatabase2Line, RiLightbulbLine, RiRocketLine, RiShieldLine } from '@remixicon/react' import { useState } from 'react' import RadioCard from '.' diff --git a/web/app/components/base/radio/index.stories.tsx b/web/app/components/base/radio/index.stories.tsx index 1f9f7173fd..61449f1b5f 100644 --- a/web/app/components/base/radio/index.stories.tsx +++ b/web/app/components/base/radio/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import Radio from '.' diff --git a/web/app/components/base/search-input/index.stories.tsx b/web/app/components/base/search-input/index.stories.tsx index 4b5323a8db..b27a6c2fb5 100644 --- a/web/app/components/base/search-input/index.stories.tsx +++ b/web/app/components/base/search-input/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import SearchInput from '.' diff --git a/web/app/components/base/segmented-control/index.stories.tsx b/web/app/components/base/segmented-control/index.stories.tsx index d7b41b3921..6ccb3e293a 100644 --- a/web/app/components/base/segmented-control/index.stories.tsx +++ b/web/app/components/base/segmented-control/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { RiLineChartLine, RiListCheck2, RiRobot2Line } from '@remixicon/react' import { useState } from 'react' import { SegmentedControl } from '.' diff --git a/web/app/components/base/select/index.stories.tsx b/web/app/components/base/select/index.stories.tsx index 793f21343a..5a7fae6cf7 100644 --- a/web/app/components/base/select/index.stories.tsx +++ b/web/app/components/base/select/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { Item } from '.' import { useState } from 'react' import Select, { PortalSelect, SimpleSelect } from '.' diff --git a/web/app/components/base/simple-pie-chart/index.stories.tsx b/web/app/components/base/simple-pie-chart/index.stories.tsx index d08c8fa0ce..05bf603629 100644 --- a/web/app/components/base/simple-pie-chart/index.stories.tsx +++ b/web/app/components/base/simple-pie-chart/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useMemo, useState } from 'react' import SimplePieChart from '.' diff --git a/web/app/components/base/skeleton/index.stories.tsx b/web/app/components/base/skeleton/index.stories.tsx index b5ea649b34..e767852406 100644 --- a/web/app/components/base/skeleton/index.stories.tsx +++ b/web/app/components/base/skeleton/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { SkeletonContainer, SkeletonPoint, diff --git a/web/app/components/base/slider/index.stories.tsx b/web/app/components/base/slider/index.stories.tsx index 7640e06c09..bde937ffad 100644 --- a/web/app/components/base/slider/index.stories.tsx +++ b/web/app/components/base/slider/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import Slider from '.' diff --git a/web/app/components/base/sort/index.stories.tsx b/web/app/components/base/sort/index.stories.tsx index 3ecf9983a0..46b46e8d1e 100644 --- a/web/app/components/base/sort/index.stories.tsx +++ b/web/app/components/base/sort/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useMemo, useState } from 'react' import Sort from '.' diff --git a/web/app/components/base/spinner/index.stories.tsx b/web/app/components/base/spinner/index.stories.tsx index d4a481e55a..f5a8e83059 100644 --- a/web/app/components/base/spinner/index.stories.tsx +++ b/web/app/components/base/spinner/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import Spinner from '.' diff --git a/web/app/components/base/svg-gallery/index.stories.tsx b/web/app/components/base/svg-gallery/index.stories.tsx index 65da97d243..ccbf320b52 100644 --- a/web/app/components/base/svg-gallery/index.stories.tsx +++ b/web/app/components/base/svg-gallery/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import SVGRenderer from '.' const SAMPLE_SVG = ` diff --git a/web/app/components/base/svg/index.stories.tsx b/web/app/components/base/svg/index.stories.tsx index 3c6a7ca0a3..3d215dad5a 100644 --- a/web/app/components/base/svg/index.stories.tsx +++ b/web/app/components/base/svg/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import SVGBtn from '.' diff --git a/web/app/components/base/switch/index.stories.tsx b/web/app/components/base/switch/index.stories.tsx index 941ebaf172..7fe7d1fbec 100644 --- a/web/app/components/base/switch/index.stories.tsx +++ b/web/app/components/base/switch/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import Switch from '.' diff --git a/web/app/components/base/tab-header/index.stories.tsx b/web/app/components/base/tab-header/index.stories.tsx index 3d7f2bf31b..2b45907788 100644 --- a/web/app/components/base/tab-header/index.stories.tsx +++ b/web/app/components/base/tab-header/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { ITabHeaderProps } from '.' import { useState } from 'react' import TabHeader from '.' diff --git a/web/app/components/base/tab-slider-new/index.stories.tsx b/web/app/components/base/tab-slider-new/index.stories.tsx index d0a412732a..56f9df4e27 100644 --- a/web/app/components/base/tab-slider-new/index.stories.tsx +++ b/web/app/components/base/tab-slider-new/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { RiSparklingFill, RiTerminalBoxLine } from '@remixicon/react' import { useState } from 'react' import TabSliderNew from '.' diff --git a/web/app/components/base/tab-slider-plain/index.stories.tsx b/web/app/components/base/tab-slider-plain/index.stories.tsx index dd8c7e0d30..e621ba43aa 100644 --- a/web/app/components/base/tab-slider-plain/index.stories.tsx +++ b/web/app/components/base/tab-slider-plain/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import TabSliderPlain from '.' diff --git a/web/app/components/base/tab-slider/index.stories.tsx b/web/app/components/base/tab-slider/index.stories.tsx index 703116fe19..0db53491bd 100644 --- a/web/app/components/base/tab-slider/index.stories.tsx +++ b/web/app/components/base/tab-slider/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useEffect, useState } from 'react' import TabSlider from '.' diff --git a/web/app/components/base/tag-input/index.stories.tsx b/web/app/components/base/tag-input/index.stories.tsx index cb9551702e..ef1b2e3365 100644 --- a/web/app/components/base/tag-input/index.stories.tsx +++ b/web/app/components/base/tag-input/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import TagInput from '.' diff --git a/web/app/components/base/tag-management/index.stories.tsx b/web/app/components/base/tag-management/index.stories.tsx index e6a088c267..cb62965901 100644 --- a/web/app/components/base/tag-management/index.stories.tsx +++ b/web/app/components/base/tag-management/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import type { Tag } from './constant' import { useEffect, useRef } from 'react' import { ToastProvider } from '@/app/components/base/toast' diff --git a/web/app/components/base/tag/index.stories.tsx b/web/app/components/base/tag/index.stories.tsx index 8ca15c0c8b..219ed4f9e8 100644 --- a/web/app/components/base/tag/index.stories.tsx +++ b/web/app/components/base/tag/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import Tag from '.' const COLORS: Array['color']>> = ['green', 'yellow', 'red', 'gray'] diff --git a/web/app/components/base/textarea/index.stories.tsx b/web/app/components/base/textarea/index.stories.tsx index 7d584368ef..0474e4c93f 100644 --- a/web/app/components/base/textarea/index.stories.tsx +++ b/web/app/components/base/textarea/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' import Textarea from '.' diff --git a/web/app/components/base/toast/index.stories.tsx b/web/app/components/base/toast/index.stories.tsx index 6ef65475cb..4ab9138070 100644 --- a/web/app/components/base/toast/index.stories.tsx +++ b/web/app/components/base/toast/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useCallback } from 'react' import Toast, { ToastProvider, useToastContext } from '.' diff --git a/web/app/components/base/tooltip/index.stories.tsx b/web/app/components/base/tooltip/index.stories.tsx index aeca69464f..9e2ce9977b 100644 --- a/web/app/components/base/tooltip/index.stories.tsx +++ b/web/app/components/base/tooltip/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import Tooltip from '.' const TooltipGrid = () => { diff --git a/web/app/components/base/video-gallery/index.stories.tsx b/web/app/components/base/video-gallery/index.stories.tsx index 7e17ee208c..93aba599ef 100644 --- a/web/app/components/base/video-gallery/index.stories.tsx +++ b/web/app/components/base/video-gallery/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import VideoGallery from '.' const VIDEO_SOURCES = [ diff --git a/web/app/components/base/voice-input/index.stories.tsx b/web/app/components/base/voice-input/index.stories.tsx index e368f4bd51..169c62014d 100644 --- a/web/app/components/base/voice-input/index.stories.tsx +++ b/web/app/components/base/voice-input/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { useState } from 'react' // Mock component since VoiceInput requires browser APIs and service dependencies diff --git a/web/app/components/base/with-input-validation/index.stories.tsx b/web/app/components/base/with-input-validation/index.stories.tsx index 167fa73e84..cb06d45956 100644 --- a/web/app/components/base/with-input-validation/index.stories.tsx +++ b/web/app/components/base/with-input-validation/index.stories.tsx @@ -1,4 +1,4 @@ -import type { Meta, StoryObj } from '@storybook/nextjs' +import type { Meta, StoryObj } from '@storybook/nextjs-vite' import { z } from 'zod' import withValidation from '.' diff --git a/web/app/components/datasets/create/website/watercrawl/index.spec.tsx b/web/app/components/datasets/create/website/watercrawl/index.spec.tsx index 4bb8267cea..646c59eb75 100644 --- a/web/app/components/datasets/create/website/watercrawl/index.spec.tsx +++ b/web/app/components/datasets/create/website/watercrawl/index.spec.tsx @@ -1,3 +1,6 @@ +/** + * @vitest-environment jsdom + */ import type { Mock } from 'vitest' import type { CrawlOptions, CrawlResultItem } from '@/models/datasets' import { fireEvent, render, screen, waitFor } from '@testing-library/react' diff --git a/web/app/components/datasets/documents/detail/completed/child-segment-list.spec.tsx b/web/app/components/datasets/documents/detail/completed/child-segment-list.spec.tsx new file mode 100644 index 0000000000..aa3e300322 --- /dev/null +++ b/web/app/components/datasets/documents/detail/completed/child-segment-list.spec.tsx @@ -0,0 +1,499 @@ +import type { DocumentContextValue } from '@/app/components/datasets/documents/detail/context' +import type { ChildChunkDetail, ChunkingMode, ParentMode } from '@/models/datasets' +import { fireEvent, render, screen } from '@testing-library/react' +import * as React from 'react' +import ChildSegmentList from './child-segment-list' + +// ============================================================================ +// Hoisted Mocks +// ============================================================================ + +const { + mockParentMode, + mockCurrChildChunk, +} = vi.hoisted(() => ({ + mockParentMode: { current: 'paragraph' as ParentMode }, + mockCurrChildChunk: { current: { childChunkInfo: undefined, showModal: false } as { childChunkInfo?: ChildChunkDetail, showModal: boolean } }, +})) + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { count?: number, ns?: string }) => { + if (key === 'segment.childChunks') + return options?.count === 1 ? 'child chunk' : 'child chunks' + if (key === 'segment.searchResults') + return 'search results' + if (key === 'segment.edited') + return 'edited' + if (key === 'operation.add') + return 'Add' + const prefix = options?.ns ? `${options.ns}.` : '' + return `${prefix}${key}` + }, + }), +})) + +// Mock document context +vi.mock('../context', () => ({ + useDocumentContext: (selector: (value: DocumentContextValue) => unknown) => { + const value: DocumentContextValue = { + datasetId: 'test-dataset-id', + documentId: 'test-document-id', + docForm: 'text' as ChunkingMode, + parentMode: mockParentMode.current, + } + return selector(value) + }, +})) + +// Mock segment list context +vi.mock('./index', () => ({ + useSegmentListContext: (selector: (value: { currChildChunk: { childChunkInfo?: ChildChunkDetail, showModal: boolean } }) => unknown) => { + return selector({ currChildChunk: mockCurrChildChunk.current }) + }, +})) + +// Mock skeleton component +vi.mock('./skeleton/full-doc-list-skeleton', () => ({ + default: () =>
Loading...
, +})) + +// Mock Empty component +vi.mock('./common/empty', () => ({ + default: ({ onClearFilter }: { onClearFilter: () => void }) => ( +
+ +
+ ), +})) + +// Mock FormattedText and EditSlice +vi.mock('../../../formatted-text/formatted', () => ({ + FormattedText: ({ children, className }: { children: React.ReactNode, className?: string }) => ( +
{children}
+ ), +})) + +vi.mock('../../../formatted-text/flavours/edit-slice', () => ({ + EditSlice: ({ label, text, onDelete, onClick, labelClassName, contentClassName }: { + label: string + text: string + onDelete: () => void + onClick: (e: React.MouseEvent) => void + labelClassName?: string + contentClassName?: string + }) => ( +
+ {label} + {text} + +
+ ), +})) + +// ============================================================================ +// Test Data Factories +// ============================================================================ + +const createMockChildChunk = (overrides: Partial = {}): ChildChunkDetail => ({ + id: `child-${Math.random().toString(36).substr(2, 9)}`, + position: 1, + segment_id: 'segment-1', + content: 'Child chunk content', + word_count: 100, + created_at: 1700000000, + updated_at: 1700000000, + type: 'automatic', + ...overrides, +}) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('ChildSegmentList', () => { + const defaultProps = { + childChunks: [] as ChildChunkDetail[], + parentChunkId: 'parent-1', + enabled: true, + } + + beforeEach(() => { + vi.clearAllMocks() + mockParentMode.current = 'paragraph' + mockCurrChildChunk.current = { childChunkInfo: undefined, showModal: false } + }) + + describe('Rendering', () => { + it('should render with empty child chunks', () => { + render() + + expect(screen.getByText(/child chunks/i)).toBeInTheDocument() + }) + + it('should render child chunks when provided', () => { + const childChunks = [ + createMockChildChunk({ id: 'child-1', position: 1, content: 'First chunk' }), + createMockChildChunk({ id: 'child-2', position: 2, content: 'Second chunk' }), + ] + + render() + + // In paragraph mode, content is collapsed by default + expect(screen.getByText(/2 child chunks/i)).toBeInTheDocument() + }) + + it('should render total count correctly with total prop in full-doc mode', () => { + mockParentMode.current = 'full-doc' + const childChunks = [createMockChildChunk()] + + // Pass inputValue="" to ensure isSearching is false + render() + + expect(screen.getByText(/5 child chunks/i)).toBeInTheDocument() + }) + + it('should render loading skeleton in full-doc mode when loading', () => { + mockParentMode.current = 'full-doc' + + render() + + expect(screen.getByTestId('full-doc-list-skeleton')).toBeInTheDocument() + }) + + it('should not render loading skeleton when not loading', () => { + mockParentMode.current = 'full-doc' + + render() + + expect(screen.queryByTestId('full-doc-list-skeleton')).not.toBeInTheDocument() + }) + }) + + describe('Paragraph Mode', () => { + beforeEach(() => { + mockParentMode.current = 'paragraph' + }) + + it('should show collapse icon in paragraph mode', () => { + const childChunks = [createMockChildChunk()] + + render() + + // Check for collapse/expand behavior + const totalRow = screen.getByText(/1 child chunk/i).closest('div') + expect(totalRow).toBeInTheDocument() + }) + + it('should toggle collapsed state when clicked', () => { + const childChunks = [createMockChildChunk({ content: 'Test content' })] + + render() + + // Initially collapsed in paragraph mode - content should not be visible + expect(screen.queryByTestId('formatted-text')).not.toBeInTheDocument() + + // Find and click the toggle area + const toggleArea = screen.getByText(/1 child chunk/i).closest('div') + + // Click to expand + if (toggleArea) + fireEvent.click(toggleArea) + + // After expansion, content should be visible + expect(screen.getByTestId('formatted-text')).toBeInTheDocument() + }) + + it('should apply opacity when disabled', () => { + const { container } = render() + + const wrapper = container.firstChild + expect(wrapper).toHaveClass('opacity-50') + }) + + it('should not apply opacity when enabled', () => { + const { container } = render() + + const wrapper = container.firstChild + expect(wrapper).not.toHaveClass('opacity-50') + }) + }) + + describe('Full-Doc Mode', () => { + beforeEach(() => { + mockParentMode.current = 'full-doc' + }) + + it('should show content by default in full-doc mode', () => { + const childChunks = [createMockChildChunk({ content: 'Full doc content' })] + + render() + + expect(screen.getByTestId('formatted-text')).toBeInTheDocument() + }) + + it('should render search input in full-doc mode', () => { + render() + + const input = document.querySelector('input') + expect(input).toBeInTheDocument() + }) + + it('should call handleInputChange when input changes', () => { + const handleInputChange = vi.fn() + + render() + + const input = document.querySelector('input') + if (input) { + fireEvent.change(input, { target: { value: 'test search' } }) + expect(handleInputChange).toHaveBeenCalledWith('test search') + } + }) + + it('should show search results text when searching', () => { + render() + + expect(screen.getByText(/3 search results/i)).toBeInTheDocument() + }) + + it('should show empty component when no results and searching', () => { + render( + , + ) + + expect(screen.getByTestId('empty-component')).toBeInTheDocument() + }) + + it('should call onClearFilter when clear button clicked in empty state', () => { + const onClearFilter = vi.fn() + + render( + , + ) + + const clearButton = screen.getByText('Clear Filter') + fireEvent.click(clearButton) + + expect(onClearFilter).toHaveBeenCalled() + }) + }) + + describe('Child Chunk Items', () => { + it('should render edited label when chunk is edited', () => { + mockParentMode.current = 'full-doc' + const editedChunk = createMockChildChunk({ + id: 'edited-chunk', + position: 1, + created_at: 1700000000, + updated_at: 1700000001, // Different from created_at + }) + + render() + + expect(screen.getByText(/C-1 · edited/i)).toBeInTheDocument() + }) + + it('should not show edited label when chunk is not edited', () => { + mockParentMode.current = 'full-doc' + const normalChunk = createMockChildChunk({ + id: 'normal-chunk', + position: 2, + created_at: 1700000000, + updated_at: 1700000000, // Same as created_at + }) + + render() + + expect(screen.getByText('C-2')).toBeInTheDocument() + expect(screen.queryByText(/edited/i)).not.toBeInTheDocument() + }) + + it('should call onClickSlice when chunk is clicked', () => { + mockParentMode.current = 'full-doc' + const onClickSlice = vi.fn() + const chunk = createMockChildChunk({ id: 'clickable-chunk' }) + + render( + , + ) + + const editSlice = screen.getByTestId('edit-slice') + fireEvent.click(editSlice) + + expect(onClickSlice).toHaveBeenCalledWith(chunk) + }) + + it('should call onDelete when delete button is clicked', () => { + mockParentMode.current = 'full-doc' + const onDelete = vi.fn() + const chunk = createMockChildChunk({ id: 'deletable-chunk', segment_id: 'seg-1' }) + + render( + , + ) + + const deleteButton = screen.getByTestId('delete-button') + fireEvent.click(deleteButton) + + expect(onDelete).toHaveBeenCalledWith('seg-1', 'deletable-chunk') + }) + + it('should apply focused styles when chunk is currently selected', () => { + mockParentMode.current = 'full-doc' + const chunk = createMockChildChunk({ id: 'focused-chunk' }) + mockCurrChildChunk.current = { childChunkInfo: chunk, showModal: true } + + render() + + const label = screen.getByTestId('edit-slice-label') + expect(label).toHaveClass('bg-state-accent-solid') + }) + }) + + describe('Add Button', () => { + it('should call handleAddNewChildChunk when Add button is clicked', () => { + const handleAddNewChildChunk = vi.fn() + + render( + , + ) + + const addButton = screen.getByText('Add') + fireEvent.click(addButton) + + expect(handleAddNewChildChunk).toHaveBeenCalledWith('parent-123') + }) + + it('should disable Add button when loading in full-doc mode', () => { + mockParentMode.current = 'full-doc' + + render() + + const addButton = screen.getByText('Add') + expect(addButton).toBeDisabled() + }) + + it('should stop propagation when Add button is clicked', () => { + const handleAddNewChildChunk = vi.fn() + const parentClickHandler = vi.fn() + + render( +
+ +
, + ) + + const addButton = screen.getByText('Add') + fireEvent.click(addButton) + + expect(handleAddNewChildChunk).toHaveBeenCalled() + // Parent should not be called due to stopPropagation + }) + }) + + describe('computeTotalInfo function', () => { + it('should return search results when searching in full-doc mode', () => { + mockParentMode.current = 'full-doc' + + render() + + expect(screen.getByText(/10 search results/i)).toBeInTheDocument() + }) + + it('should return "--" when total is 0 in full-doc mode', () => { + mockParentMode.current = 'full-doc' + + render() + + // When total is 0, displayText is '--' + expect(screen.getByText(/--/)).toBeInTheDocument() + }) + + it('should use childChunks length in paragraph mode', () => { + mockParentMode.current = 'paragraph' + const childChunks = [ + createMockChildChunk(), + createMockChildChunk(), + createMockChildChunk(), + ] + + render() + + expect(screen.getByText(/3 child chunks/i)).toBeInTheDocument() + }) + }) + + describe('Focused State', () => { + it('should not apply opacity when focused even if disabled', () => { + const { container } = render( + , + ) + + const wrapper = container.firstChild + expect(wrapper).not.toHaveClass('opacity-50') + }) + }) + + describe('Input clear button', () => { + it('should call handleInputChange with empty string when clear is clicked', () => { + mockParentMode.current = 'full-doc' + const handleInputChange = vi.fn() + + render( + , + ) + + // Find the clear button (it's the showClearIcon button in Input) + const input = document.querySelector('input') + if (input) { + // Trigger clear by simulating the input's onClear + const clearButton = document.querySelector('[class*="cursor-pointer"]') + if (clearButton) + fireEvent.click(clearButton) + } + }) + }) +}) diff --git a/web/app/components/datasets/documents/detail/completed/child-segment-list.tsx b/web/app/components/datasets/documents/detail/completed/child-segment-list.tsx index b23aac6af9..fd6fd338d0 100644 --- a/web/app/components/datasets/documents/detail/completed/child-segment-list.tsx +++ b/web/app/components/datasets/documents/detail/completed/child-segment-list.tsx @@ -1,7 +1,7 @@ import type { FC } from 'react' import type { ChildChunkDetail } from '@/models/datasets' import { RiArrowDownSLine, RiArrowRightSLine } from '@remixicon/react' -import { useMemo, useState } from 'react' +import { useState } from 'react' import { useTranslation } from 'react-i18next' import Divider from '@/app/components/base/divider' import Input from '@/app/components/base/input' @@ -29,6 +29,37 @@ type IChildSegmentCardProps = { focused?: boolean } +function computeTotalInfo( + isFullDocMode: boolean, + isSearching: boolean, + total: number | undefined, + childChunksLength: number, +): { displayText: string, count: number, translationKey: 'segment.searchResults' | 'segment.childChunks' } { + if (isSearching) { + const count = total ?? 0 + return { + displayText: count === 0 ? '--' : String(formatNumber(count)), + count, + translationKey: 'segment.searchResults', + } + } + + if (isFullDocMode) { + const count = total ?? 0 + return { + displayText: count === 0 ? '--' : String(formatNumber(count)), + count, + translationKey: 'segment.childChunks', + } + } + + return { + displayText: String(formatNumber(childChunksLength)), + count: childChunksLength, + translationKey: 'segment.childChunks', + } +} + const ChildSegmentList: FC = ({ childChunks, parentChunkId, @@ -49,59 +80,87 @@ const ChildSegmentList: FC = ({ const [collapsed, setCollapsed] = useState(true) - const toggleCollapse = () => { - setCollapsed(!collapsed) + const isParagraphMode = parentMode === 'paragraph' + const isFullDocMode = parentMode === 'full-doc' + const isSearching = inputValue !== '' && isFullDocMode + const contentOpacity = (enabled || focused) ? '' : 'opacity-50 group-hover/card:opacity-100' + const { displayText, count, translationKey } = computeTotalInfo(isFullDocMode, isSearching, total, childChunks.length) + const totalText = `${displayText} ${t(translationKey, { ns: 'datasetDocuments', count })}` + + const toggleCollapse = () => setCollapsed(prev => !prev) + const showContent = (isFullDocMode && !isLoading) || !collapsed + const hoverVisibleClass = isParagraphMode ? 'hidden group-hover/card:inline-block' : '' + + const renderCollapseIcon = () => { + if (!isParagraphMode) + return null + const Icon = collapsed ? RiArrowRightSLine : RiArrowDownSLine + return } - const isParagraphMode = useMemo(() => { - return parentMode === 'paragraph' - }, [parentMode]) + const renderChildChunkItem = (childChunk: ChildChunkDetail) => { + const isEdited = childChunk.updated_at !== childChunk.created_at + const isFocused = currChildChunk?.childChunkInfo?.id === childChunk.id + const label = isEdited + ? `C-${childChunk.position} · ${t('segment.edited', { ns: 'datasetDocuments' })}` + : `C-${childChunk.position}` - const isFullDocMode = useMemo(() => { - return parentMode === 'full-doc' - }, [parentMode]) + return ( + onDelete?.(childChunk.segment_id, childChunk.id)} + className="child-chunk" + labelClassName={isFocused ? 'bg-state-accent-solid text-text-primary-on-surface' : ''} + labelInnerClassName="text-[10px] font-semibold align-bottom leading-6" + contentClassName={cn('!leading-6', isFocused ? 'bg-state-accent-hover-alt text-text-primary' : 'text-text-secondary')} + showDivider={false} + onClick={(e) => { + e.stopPropagation() + onClickSlice?.(childChunk) + }} + offsetOptions={({ rects }) => ({ + mainAxis: isFullDocMode ? -rects.floating.width : 12 - rects.floating.width, + crossAxis: (20 - rects.floating.height) / 2, + })} + /> + ) + } - const contentOpacity = useMemo(() => { - return (enabled || focused) ? '' : 'opacity-50 group-hover/card:opacity-100' - }, [enabled, focused]) - - const totalText = useMemo(() => { - const isSearch = inputValue !== '' && isFullDocMode - if (!isSearch) { - const text = isFullDocMode - ? !total - ? '--' - : formatNumber(total) - : formatNumber(childChunks.length) - const count = isFullDocMode - ? text === '--' - ? 0 - : total - : childChunks.length - return `${text} ${t('segment.childChunks', { ns: 'datasetDocuments', count })}` + const renderContent = () => { + if (childChunks.length > 0) { + return ( + + {childChunks.map(renderChildChunkItem)} + + ) } - else { - const text = !total ? '--' : formatNumber(total) - const count = text === '--' ? 0 : total - return `${count} ${t('segment.searchResults', { ns: 'datasetDocuments', count })}` + if (inputValue !== '') { + return ( +
+ +
+ ) } - }, [isFullDocMode, total, childChunks.length, inputValue]) + return null + } return (
- {isFullDocMode ? : null} -
+ {isFullDocMode && } +
{ @@ -109,23 +168,15 @@ const ChildSegmentList: FC = ({ toggleCollapse() }} > - { - isParagraphMode - ? collapsed - ? ( - - ) - : () - : null - } + {renderCollapseIcon()} {totalText} - · + ·
- {isFullDocMode - ? ( - handleInputChange?.(e.target.value)} - onClear={() => handleInputChange?.('')} - /> - ) - : null} + {isFullDocMode && ( + handleInputChange?.(e.target.value)} + onClear={() => handleInputChange?.('')} + /> + )}
- {isLoading ? : null} - {((isFullDocMode && !isLoading) || !collapsed) - ? ( -
- {isParagraphMode && ( -
- -
- )} - {childChunks.length > 0 - ? ( - - {childChunks.map((childChunk) => { - const edited = childChunk.updated_at !== childChunk.created_at - const focused = currChildChunk?.childChunkInfo?.id === childChunk.id - return ( - onDelete?.(childChunk.segment_id, childChunk.id)} - className="child-chunk" - labelClassName={focused ? 'bg-state-accent-solid text-text-primary-on-surface' : ''} - labelInnerClassName="text-[10px] font-semibold align-bottom leading-6" - contentClassName={cn('!leading-6', focused ? 'bg-state-accent-hover-alt text-text-primary' : 'text-text-secondary')} - showDivider={false} - onClick={(e) => { - e.stopPropagation() - onClickSlice?.(childChunk) - }} - offsetOptions={({ rects }) => { - return { - mainAxis: isFullDocMode ? -rects.floating.width : 12 - rects.floating.width, - crossAxis: (20 - rects.floating.height) / 2, - } - }} - /> - ) - })} - - ) - : inputValue !== '' - ? ( -
- -
- ) - : null} + {isLoading && } + {showContent && ( +
+ {isParagraphMode && ( +
+
- ) - : null} + )} + {renderContent()} +
+ )}
) } diff --git a/web/app/components/datasets/documents/detail/completed/common/drawer.tsx b/web/app/components/datasets/documents/detail/completed/common/drawer.tsx index dc1b7192c3..a68742890a 100644 --- a/web/app/components/datasets/documents/detail/completed/common/drawer.tsx +++ b/web/app/components/datasets/documents/detail/completed/common/drawer.tsx @@ -17,6 +17,31 @@ type DrawerProps = { needCheckChunks?: boolean } +const SIDE_POSITION_CLASS = { + right: 'right-0', + left: 'left-0', + bottom: 'bottom-0', + top: 'top-0', +} as const + +function containsTarget(selector: string, target: Node | null): boolean { + const elements = document.querySelectorAll(selector) + return Array.from(elements).some(el => el?.contains(target)) +} + +function shouldReopenChunkDetail( + isClickOnChunk: boolean, + isClickOnChildChunk: boolean, + segmentModalOpen: boolean, + childChunkModalOpen: boolean, +): boolean { + if (segmentModalOpen && isClickOnChildChunk) + return true + if (childChunkModalOpen && isClickOnChunk && !isClickOnChildChunk) + return true + return !isClickOnChunk && !isClickOnChildChunk +} + const Drawer = ({ open, onClose, @@ -41,22 +66,22 @@ const Drawer = ({ const shouldCloseDrawer = useCallback((target: Node | null) => { const panelContent = panelContentRef.current - if (!panelContent) + if (!panelContent || !target) return false - const chunks = document.querySelectorAll('.chunk-card') - const childChunks = document.querySelectorAll('.child-chunk') - const imagePreviewer = document.querySelector('.image-previewer') - const isClickOnChunk = Array.from(chunks).some((chunk) => { - return chunk && chunk.contains(target) - }) - const isClickOnChildChunk = Array.from(childChunks).some((chunk) => { - return chunk && chunk.contains(target) - }) - const reopenChunkDetail = (currSegment.showModal && isClickOnChildChunk) - || (currChildChunk.showModal && isClickOnChunk && !isClickOnChildChunk) || (!isClickOnChunk && !isClickOnChildChunk) - const isClickOnImagePreviewer = imagePreviewer && imagePreviewer.contains(target) - return target && !panelContent.contains(target) && (!needCheckChunks || reopenChunkDetail) && !isClickOnImagePreviewer - }, [currSegment, currChildChunk, needCheckChunks]) + + if (panelContent.contains(target)) + return false + + if (containsTarget('.image-previewer', target)) + return false + + if (!needCheckChunks) + return true + + const isClickOnChunk = containsTarget('.chunk-card', target) + const isClickOnChildChunk = containsTarget('.child-chunk', target) + return shouldReopenChunkDetail(isClickOnChunk, isClickOnChildChunk, currSegment.showModal, currChildChunk.showModal) + }, [currSegment.showModal, currChildChunk.showModal, needCheckChunks]) const onDownCapture = useCallback((e: PointerEvent) => { if (!open || modal) @@ -77,32 +102,27 @@ const Drawer = ({ const isHorizontal = side === 'left' || side === 'right' + const overlayPointerEvents = modal && open ? 'pointer-events-auto' : 'pointer-events-none' + const content = (
- {showOverlay - ? ( -