diff --git a/AGENTS.md b/AGENTS.md index deab7c8629..7d96ac3a6d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -25,6 +25,30 @@ pnpm type-check:tsgo pnpm test ``` +### Frontend Linting + +ESLint is used for frontend code quality. Available commands: + +```bash +# Lint all files (report only) +pnpm lint + +# Lint and auto-fix issues +pnpm lint:fix + +# Lint specific files or directories +pnpm lint:fix app/components/base/button/ +pnpm lint:fix app/components/base/button/index.tsx + +# Lint quietly (errors only, no warnings) +pnpm lint:quiet + +# Check code complexity +pnpm lint:complexity +``` + +**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files. + ## Testing & Quality Practices - Follow TDD: red → green → refactor. diff --git a/agent-notes/.gitkeep b/agent-notes/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/agent-notes/api/core/model_runtime/model_providers/__base/large_language_model.py.md b/agent-notes/api/core/model_runtime/model_providers/__base/large_language_model.py.md deleted file mode 100644 index f03c41cc25..0000000000 --- a/agent-notes/api/core/model_runtime/model_providers/__base/large_language_model.py.md +++ /dev/null @@ -1,27 +0,0 @@ -# Notes: `large_language_model.py` - -## Purpose - -Provides the base `LargeLanguageModel` implementation used by the model runtime to invoke plugin-backed LLMs and to -bridge plugin daemon streaming semantics back into API-layer entities (`LLMResult`, `LLMResultChunk`). - -## Key behaviors / invariants - -- `invoke(..., stream=False)` still calls the plugin in streaming mode and then synthesizes a single `LLMResult` from - the first yielded `LLMResultChunk`. -- Plugin invocation is wrapped by `_invoke_llm_via_plugin(...)`, and `stream=False` normalization is handled by - `_normalize_non_stream_plugin_result(...)` / `_build_llm_result_from_first_chunk(...)`. -- Tool call deltas are merged incrementally via `_increase_tool_call(...)` to support multiple provider chunking - patterns (IDs anchored to first chunk, every chunk, or missing entirely). -- A tool-call delta with an empty `id` requires at least one existing tool call; otherwise we raise `ValueError` to - surface invalid delta sequences explicitly. -- Callback invocation is centralized in `_run_callbacks(...)` to ensure consistent error handling/logging. -- For compatibility with dify issue `#17799`, `prompt_messages` may be removed by the plugin daemon in chunks and must - be re-attached in this layer before callbacks/consumers use them. -- Callback hooks (`on_before_invoke`, `on_new_chunk`, `on_after_invoke`, `on_invoke_error`) must not break invocation - unless `callback.raise_error` is true. - -## Test focus - -- `api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py` validates tool-call delta merging and - patches `_gen_tool_call_id` for deterministic IDs. diff --git a/api/AGENTS.md b/api/AGENTS.md index 6ce419828b..13adb42276 100644 --- a/api/AGENTS.md +++ b/api/AGENTS.md @@ -1,97 +1,47 @@ # API Agent Guide -## Agent Notes (must-check) +## Notes for Agent (must-check) -Before you start work on any backend file under `api/`, you MUST check whether a related note exists under: +Before changing any backend code under `api/`, you MUST read the surrounding docstrings and comments. These notes contain required context (invariants, edge cases, trade-offs) and are treated as part of the spec. -- `agent-notes/.md` +Look for: -Rules: +- The module (file) docstring at the top of a source code file +- Docstrings on classes and functions/methods +- Paragraph/block comments for non-obvious logic -- **Path mapping**: for a target file `/.py`, the note must be `agent-notes//.py.md` (same folder structure, same filename, plus `.md`). -- **Before working**: - - If the note exists, read it first and follow any constraints/decisions recorded there. - - If the note conflicts with the current code, or references an "origin" file/path that has been deleted, renamed, or migrated, treat the **code as the single source of truth** and update the note to match reality. - - If the note does not exist, create it with a short architecture/intent summary and any relevant invariants/edge cases. -- **During working**: - - Keep the note in sync as you discover constraints, make decisions, or change approach. - - If you move/rename a file, migrate its note to the new mapped path (and fix any outdated references inside the note). - - Record non-obvious edge cases, trade-offs, and the test/verification plan as you go (not just at the end). - - Keep notes **coherent**: integrate new findings into the relevant sections and rewrite for clarity; avoid append-only “recent fix” / changelog-style additions unless the note is explicitly intended to be a changelog. -- **When finishing work**: - - Update the related note(s) to reflect what changed, why, and any new edge cases/tests. - - If a file is deleted, remove or clearly deprecate the corresponding note so it cannot be mistaken as current guidance. - - Keep notes concise and accurate; they are meant to prevent repeated rediscovery. +### What to write where -## Skill Index +- Keep notes scoped: module notes cover module-wide context, class notes cover class-wide context, function/method notes cover behavioural contracts, and paragraph/block comments cover local “why”. Avoid duplicating the same content across scopes unless repetition prevents misuse. +- **Module (file) docstring**: purpose, boundaries, key invariants, and “gotchas” that a new reader must know before editing. + - Include cross-links to the key collaborators (modules/services) when discovery is otherwise hard. + - Prefer stable facts (invariants, contracts) over ephemeral “today we…” notes. +- **Class docstring**: responsibility, lifecycle, invariants, and how it should be used (or not used). + - If the class is intentionally stateful, note what state exists and what methods mutate it. + - If concurrency/async assumptions matter, state them explicitly. +- **Function/method docstring**: behavioural contract. + - Document arguments, return shape, side effects (DB writes, external I/O, task dispatch), and raised domain exceptions. + - Add examples only when they prevent misuse. +- **Paragraph/block comments**: explain *why* (trade-offs, historical constraints, surprising edge cases), not what the code already states. + - Keep comments adjacent to the logic they justify; delete or rewrite comments that no longer match reality. -Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it. +### Rules (must follow) -### Platform Foundations +In this section, “notes” means module/class/function docstrings plus any relevant paragraph/block comments. -#### [Infrastructure Overview](agent_skills/infra.md) - -- **When to read this** - - You need to understand where a feature belongs in the architecture. - - You’re wiring storage, Redis, vector stores, or OTEL. - - You’re about to add CLI commands or async jobs. -- **What it covers** - - Configuration stack (`configs/app_config.py`, remote settings) - - Storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`) - - Redis conventions (`extensions/ext_redis.py`) - - Plugin runtime topology - - Vector-store factory (`core/rag/datasource/vdb/*`) - - Observability hooks - - SSRF proxy usage - - Core CLI commands - -### Plugin & Extension Development - -#### [Plugin Systems](agent_skills/plugin.md) - -- **When to read this** - - You’re building or debugging a marketplace plugin. - - You need to know how manifests, providers, daemons, and migrations fit together. -- **What it covers** - - Plugin manifests (`core/plugin/entities/plugin.py`) - - Installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands) - - Runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent) - - Daemon coordination (`core/plugin/entities/plugin_daemon.py`) - - How provider registries surface capabilities to the rest of the platform - -#### [Plugin OAuth](agent_skills/plugin_oauth.md) - -- **When to read this** - - You must integrate OAuth for a plugin or datasource. - - You’re handling credential encryption or refresh flows. -- **Topics** - - Credential storage - - Encryption helpers (`core/helper/provider_encryption.py`) - - OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`) - - How console/API layers expose the flows - -### Workflow Entry & Execution - -#### [Trigger Concepts](agent_skills/trigger.md) - -- **When to read this** - - You’re debugging why a workflow didn’t start. - - You’re adding a new trigger type or hook. - - You need to trace async execution, draft debugging, or webhook/schedule pipelines. -- **Details** - - Start-node taxonomy - - Webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`) - - Async orchestration (`services/async_workflow_service.py`, Celery queues) - - Debug event bus - - Storage/logging interactions - -## General Reminders - -- All skill docs assume you follow the coding style rules below—run the lint/type/test commands before submitting changes. -- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`). -- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules. -- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`. -- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently. +- **Before working** + - Read the notes in the area you’ll touch; treat them as part of the spec. + - If a docstring or comment conflicts with the current code, treat the **code as the single source of truth** and update the docstring or comment to match reality. + - If important intent/invariants/edge cases are missing, add them in the closest docstring or comment (module for overall scope, function for behaviour). +- **During working** + - Keep the notes in sync as you discover constraints, make decisions, or change approach. + - If you move/rename responsibilities across modules/classes, update the affected docstrings and comments so readers can still find the “why” and the invariants. + - Record non-obvious edge cases, trade-offs, and the test/verification plan in the nearest docstring or comment that will stay correct. + - Keep the notes **coherent**: integrate new findings into the relevant docstrings and comments; avoid append-only “recent fix” / changelog-style additions. +- **When finishing** + - Update the notes to reflect what changed, why, and any new edge cases/tests. + - Remove or rewrite any comments that could be mistaken as current guidance but no longer apply. + - Keep docstrings and comments concise and accurate; they are meant to prevent repeated rediscovery. ## Coding Style @@ -226,7 +176,7 @@ Before opening a PR / submitting: - Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic. - Services: coordinate repositories, providers, background tasks; keep side effects explicit. -- Document non-obvious behaviour with concise comments. +- Document non-obvious behaviour with concise docstrings and comments. ### Miscellaneous diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index cd958bbb36..d05e726dcb 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -36,6 +36,16 @@ class NotionEstimatePayload(BaseModel): doc_language: str = Field(default="English") +class DataSourceNotionListQuery(BaseModel): + dataset_id: str | None = Field(default=None, description="Dataset ID") + credential_id: str = Field(..., description="Credential ID", min_length=1) + datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string") + + +class DataSourceNotionPreviewQuery(BaseModel): + credential_id: str = Field(..., description="Credential ID", min_length=1) + + register_schema_model(console_ns, NotionEstimatePayload) @@ -136,26 +146,15 @@ class DataSourceNotionListApi(Resource): def get(self): current_user, current_tenant_id = current_account_with_tenant() - dataset_id = request.args.get("dataset_id", default=None, type=str) - credential_id = request.args.get("credential_id", default=None, type=str) - if not credential_id: - raise ValueError("Credential id is required.") + query = DataSourceNotionListQuery.model_validate(request.args.to_dict()) # Get datasource_parameters from query string (optional, for GitHub and other datasources) - datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str) - datasource_parameters = {} - if datasource_parameters_str: - try: - datasource_parameters = json.loads(datasource_parameters_str) - if not isinstance(datasource_parameters, dict): - raise ValueError("datasource_parameters must be a JSON object.") - except json.JSONDecodeError: - raise ValueError("Invalid datasource_parameters JSON format.") + datasource_parameters = query.datasource_parameters or {} datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( tenant_id=current_tenant_id, - credential_id=credential_id, + credential_id=query.credential_id, provider="notion_datasource", plugin_id="langgenius/notion_datasource", ) @@ -164,8 +163,8 @@ class DataSourceNotionListApi(Resource): exist_page_ids = [] with Session(db.engine) as session: # import notion in the exist dataset - if dataset_id: - dataset = DatasetService.get_dataset(dataset_id) + if query.dataset_id: + dataset = DatasetService.get_dataset(query.dataset_id) if not dataset: raise NotFound("Dataset not found.") if dataset.data_source_type != "notion_import": @@ -173,7 +172,7 @@ class DataSourceNotionListApi(Resource): documents = session.scalars( select(Document).filter_by( - dataset_id=dataset_id, + dataset_id=query.dataset_id, tenant_id=current_tenant_id, data_source_type="notion_import", enabled=True, @@ -240,13 +239,12 @@ class DataSourceNotionApi(Resource): def get(self, page_id, page_type): _, current_tenant_id = current_account_with_tenant() - credential_id = request.args.get("credential_id", default=None, type=str) - if not credential_id: - raise ValueError("Credential id is required.") + query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict()) + datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( tenant_id=current_tenant_id, - credential_id=credential_id, + credential_id=query.credential_id, provider="notion_datasource", plugin_id="langgenius/notion_datasource", ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 8ceb896d4f..e9371b608c 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -176,7 +176,18 @@ class IndexingEstimatePayload(BaseModel): return result -register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload) +class ConsoleDatasetListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + include_all: bool = Field(default=False, description="Include all datasets") + ids: list[str] = Field(default_factory=list, description="Filter by dataset IDs") + tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs") + + +register_schema_models( + console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery +) def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]: @@ -275,18 +286,19 @@ class DatasetListApi(Resource): @enterprise_license_required def get(self): current_user, current_tenant_id = current_account_with_tenant() - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - ids = request.args.getlist("ids") + query = ConsoleDatasetListQuery.model_validate(request.args.to_dict(flat=False)) # provider = request.args.get("provider", default="vendor") - search = request.args.get("keyword", default=None, type=str) - tag_ids = request.args.getlist("tag_ids") - include_all = request.args.get("include_all", default="false").lower() == "true" - if ids: - datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id) + if query.ids: + datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id) else: datasets, total = DatasetService.get_datasets( - page, limit, current_tenant_id, current_user, search, tag_ids, include_all + query.page, + query.limit, + current_tenant_id, + current_user, + query.keyword, + query.tag_ids, + query.include_all, ) # check embedding setting @@ -318,7 +330,13 @@ class DatasetListApi(Resource): else: item.update({"partial_member_list": []}) - response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} + response = { + "data": data, + "has_more": len(datasets) == query.limit, + "limit": query.limit, + "total": total, + "page": query.page, + } return response, 200 @console_ns.doc("create_dataset") diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index a70a7ce480..588eb6e1b8 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -98,12 +98,19 @@ class BedrockRetrievalPayload(BaseModel): knowledge_id: str +class ExternalApiTemplateListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + + register_schema_models( console_ns, ExternalKnowledgeApiPayload, ExternalDatasetCreatePayload, ExternalHitTestingPayload, BedrockRetrievalPayload, + ExternalApiTemplateListQuery, ) @@ -124,19 +131,17 @@ class ExternalApiTemplateListApi(Resource): @account_initialization_required def get(self): _, current_tenant_id = current_account_with_tenant() - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - search = request.args.get("keyword", default=None, type=str) + query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict()) external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis( - page, limit, current_tenant_id, search + query.page, query.limit, current_tenant_id, query.keyword ) response = { "data": [item.to_dict() for item in external_knowledge_apis], - "has_more": len(external_knowledge_apis) == limit, - "limit": limit, + "has_more": len(external_knowledge_apis) == query.limit, + "limit": query.limit, "total": total, - "page": page, + "page": query.page, } return response, 200 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index e42db10ba6..b77eac605e 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -3,7 +3,7 @@ from typing import Any from flask import request from flask_restx import Resource, marshal_with -from pydantic import BaseModel +from pydantic import BaseModel, Field from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -28,6 +28,10 @@ class InstalledAppUpdatePayload(BaseModel): is_pinned: bool | None = None +class InstalledAppsListQuery(BaseModel): + app_id: str | None = Field(default=None, description="App ID to filter by") + + logger = logging.getLogger(__name__) @@ -37,13 +41,13 @@ class InstalledAppsListApi(Resource): @account_initialization_required @marshal_with(installed_app_list_fields) def get(self): - app_id = request.args.get("app_id", default=None, type=str) + query = InstalledAppsListQuery.model_validate(request.args.to_dict()) current_user, current_tenant_id = current_account_with_tenant() - if app_id: + if query.app_id: installed_apps = db.session.scalars( select(InstalledApp).where( - and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id) + and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == query.app_id) ) ).all() else: diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 023ffc991a..9988524a80 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -40,6 +40,7 @@ register_schema_models( TagBasePayload, TagBindingPayload, TagBindingRemovePayload, + TagListQueryParam, ) diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 94faf8dd42..b036a71f18 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -87,6 +87,14 @@ class TagUnbindingPayload(BaseModel): target_id: str +class DatasetListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + include_all: bool = Field(default=False, description="Include all datasets") + tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs") + + register_schema_models( service_api_ns, DatasetCreatePayload, @@ -96,6 +104,7 @@ register_schema_models( TagDeletePayload, TagBindingPayload, TagUnbindingPayload, + DatasetListQuery, ) @@ -113,15 +122,11 @@ class DatasetListApi(DatasetApiResource): ) def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) + query = DatasetListQuery.model_validate(request.args.to_dict(flat=False)) # provider = request.args.get("provider", default="vendor") - search = request.args.get("keyword", default=None, type=str) - tag_ids = request.args.getlist("tag_ids") - include_all = request.args.get("include_all", default="false").lower() == "true" datasets, total = DatasetService.get_datasets( - page, limit, tenant_id, current_user, search, tag_ids, include_all + query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all ) # check embedding setting provider_manager = ProviderManager() @@ -147,7 +152,13 @@ class DatasetListApi(DatasetApiResource): item["embedding_available"] = False else: item["embedding_available"] = True - response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} + response = { + "data": data, + "has_more": len(datasets) == query.limit, + "limit": query.limit, + "total": total, + "page": query.page, + } return response, 200 @service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__]) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 49ff4f57dc..1260645624 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -69,7 +69,14 @@ class DocumentTextUpdate(BaseModel): return self -for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]: +class DocumentListQuery(BaseModel): + page: int = Field(default=1, description="Page number") + limit: int = Field(default=20, description="Number of items per page") + keyword: str | None = Field(default=None, description="Search keyword") + status: str | None = Field(default=None, description="Document status filter") + + +for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate, DocumentListQuery]: service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore @@ -460,34 +467,33 @@ class DocumentListApi(DatasetApiResource): def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - search = request.args.get("keyword", default=None, type=str) - status = request.args.get("status", default=None, type=str) + query_params = DocumentListQuery.model_validate(request.args.to_dict()) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) - if status: - query = DocumentService.apply_display_status_filter(query, status) + if query_params.status: + query = DocumentService.apply_display_status_filter(query, query_params.status) - if search: - search = f"%{search}%" + if query_params.keyword: + search = f"%{query_params.keyword}%" query = query.where(Document.name.like(search)) query = query.order_by(desc(Document.created_at), desc(Document.position)) - paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = db.paginate( + select=query, page=query_params.page, per_page=query_params.limit, max_per_page=100, error_out=False + ) documents = paginated_documents.items response = { "data": marshal(documents, document_fields), - "has_more": len(documents) == limit, - "limit": limit, + "has_more": len(documents) == query_params.limit, + "limit": query_params.limit, "total": paginated_documents.total, - "page": page, + "page": query_params.page, } return response diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index aab25c1af3..b8d9508004 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -11,7 +11,9 @@ from controllers.service_api.wraps import DatasetApiResource, cloud_edition_bill from fields.dataset_fields import dataset_metadata_fields from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, MetadataArgs, + MetadataDetail, MetadataOperationData, ) from services.metadata_service import MetadataService @@ -22,7 +24,13 @@ class MetadataUpdatePayload(BaseModel): register_schema_model(service_api_ns, MetadataUpdatePayload) -register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData) +register_schema_models( + service_api_ns, + MetadataArgs, + MetadataDetail, + DocumentMetadataOperation, + MetadataOperationData, +) @service_api_ns.route("/datasets//metadata") diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 2760466a3b..8b6b8f227b 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner): queue_manager=queue_manager, stream=application_generate_entity.stream, agent=True, + message_id=message.id, + user_id=application_generate_entity.user_id, + tenant_id=app_config.tenant_id, ) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index e2e6c11480..617515945b 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,6 +1,8 @@ +import base64 import logging import time from collections.abc import Generator, Mapping, Sequence +from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity @@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ModelConfigWithCredentialsEntity, ) -from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, +) from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch +from core.file.enums import FileTransferMethod, FileType from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, + TextPromptMessageContent, ) from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.invoke import InvokeBadRequestError @@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform -from models.model import App, AppMode, Message, MessageAnnotation +from core.tools.tool_file_manager import ToolFileManager +from extensions.ext_database import db +from models.enums import CreatorUserRole +from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: from core.file.models import File @@ -203,6 +215,9 @@ class AppRunner: queue_manager: AppQueueManager, stream: bool, agent: bool = False, + message_id: str | None = None, + user_id: str | None = None, + tenant_id: str | None = None, ): """ Handle invoke result @@ -210,21 +225,41 @@ class AppRunner: :param queue_manager: application queue manager :param stream: stream :param agent: agent + :param message_id: message id for multimodal output + :param user_id: user id for multimodal output + :param tenant_id: tenant id for multimodal output :return: """ if not stream and isinstance(invoke_result, LLMResult): - self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) + self._handle_invoke_result_direct( + invoke_result=invoke_result, + queue_manager=queue_manager, + ) elif stream and isinstance(invoke_result, Generator): - self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) + self._handle_invoke_result_stream( + invoke_result=invoke_result, + queue_manager=queue_manager, + agent=agent, + message_id=message_id, + user_id=user_id, + tenant_id=tenant_id, + ) else: raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") - def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool): + def _handle_invoke_result_direct( + self, + invoke_result: LLMResult, + queue_manager: AppQueueManager, + ): """ Handle invoke result direct :param invoke_result: invoke result :param queue_manager: application queue manager :param agent: agent + :param message_id: message id for multimodal output + :param user_id: user id for multimodal output + :param tenant_id: tenant id for multimodal output :return: """ queue_manager.publish( @@ -235,13 +270,22 @@ class AppRunner: ) def _handle_invoke_result_stream( - self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool + self, + invoke_result: Generator[LLMResultChunk, None, None], + queue_manager: AppQueueManager, + agent: bool, + message_id: str | None = None, + user_id: str | None = None, + tenant_id: str | None = None, ): """ Handle invoke result :param invoke_result: invoke result :param queue_manager: application queue manager :param agent: agent + :param message_id: message id for multimodal output + :param user_id: user id for multimodal output + :param tenant_id: tenant id for multimodal output :return: """ model: str = "" @@ -259,12 +303,26 @@ class AppRunner: text += message.content elif isinstance(message.content, list): for content in message.content: - if not isinstance(content, str): - # TODO(QuantumGhost): Add multimodal output support for easy ui. - _logger.warning("received multimodal output, type=%s", type(content)) + if isinstance(content, str): + text += content + elif isinstance(content, TextPromptMessageContent): text += content.data + elif isinstance(content, ImagePromptMessageContent): + if message_id and user_id and tenant_id: + try: + self._handle_multimodal_image_content( + content=content, + message_id=message_id, + user_id=user_id, + tenant_id=tenant_id, + queue_manager=queue_manager, + ) + except Exception: + _logger.exception("Failed to handle multimodal image output") + else: + _logger.warning("Received multimodal output but missing required parameters") else: - text += content # failback to str + text += content.data if hasattr(content, "data") else str(content) if not model: model = result.model @@ -289,6 +347,101 @@ class AppRunner: PublishFrom.APPLICATION_MANAGER, ) + def _handle_multimodal_image_content( + self, + content: ImagePromptMessageContent, + message_id: str, + user_id: str, + tenant_id: str, + queue_manager: AppQueueManager, + ): + """ + Handle multimodal image content from LLM response. + Save the image and create a MessageFile record. + + :param content: ImagePromptMessageContent instance + :param message_id: message id + :param user_id: user id + :param tenant_id: tenant id + :param queue_manager: queue manager + :return: + """ + _logger.info("Handling multimodal image content for message %s", message_id) + + image_url = content.url + base64_data = content.base64_data + + _logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data) + + if not image_url and not base64_data: + _logger.warning("Image content has neither URL nor base64 data") + return + + tool_file_manager = ToolFileManager() + + # Save the image file + try: + if image_url: + # Download image from URL + _logger.info("Downloading image from URL: %s", image_url) + tool_file = tool_file_manager.create_file_by_url( + user_id=user_id, + tenant_id=tenant_id, + file_url=image_url, + conversation_id=None, + ) + _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id) + elif base64_data: + if base64_data.startswith("data:"): + base64_data = base64_data.split(",", 1)[1] + + image_binary = base64.b64decode(base64_data) + mimetype = content.mime_type or "image/png" + extension = guess_extension(mimetype) or ".png" + + tool_file = tool_file_manager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=image_binary, + mimetype=mimetype, + filename=f"generated_image{extension}", + ) + _logger.info("Image saved successfully, tool_file_id: %s", tool_file.id) + else: + return + except Exception: + _logger.exception("Failed to save image file") + return + + # Create MessageFile record + message_file = MessageFile( + message_id=message_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + belongs_to="assistant", + url=f"/files/tools/{tool_file.id}", + upload_file_id=tool_file.id, + created_by_role=( + CreatorUserRole.ACCOUNT + if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} + else CreatorUserRole.END_USER + ), + created_by=user_id, + ) + + db.session.add(message_file) + db.session.commit() + db.session.refresh(message_file) + + # Publish QueueMessageFileEvent + queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file.id), + PublishFrom.APPLICATION_MANAGER, + ) + + _logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id) + def moderation_for_inputs( self, *, diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index f8338b226b..7d1a4c619f 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream + invoke_result=invoke_result, + queue_manager=queue_manager, + stream=application_generate_entity.stream, + message_id=message.id, + user_id=application_generate_entity.user_id, + tenant_id=app_config.tenant_id, ) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index ddfb5725b4..a872c2e1f7 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream + invoke_result=invoke_result, + queue_manager=queue_manager, + stream=application_generate_entity.stream, + message_id=message.id, + user_id=application_generate_entity.user_id, + tenant_id=app_config.tenant_id, ) diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 5bb93fa44a..6c997753fa 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -39,6 +39,7 @@ from core.app.entities.task_entities import ( MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, + StreamEvent, StreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline @@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): _task_state: EasyUITaskState _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] + _precomputed_event_type: StreamEvent | None = None def __init__( self, @@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id) + # Determine the event type once, on first LLM chunk, and reuse for subsequent chunks + if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None: + self._precomputed_event_type = self._message_cycle_manager.get_message_event_type( + message_id=self._message_id + ) yield self._message_cycle_manager.message_to_stream_response( answer=cast(str, delta_text), message_id=self._message_id, - event_type=event_type, + event_type=self._precomputed_event_type, ) else: yield self._agent_message_to_stream_response( diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 0e7f300cee..2d4ee08daf 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -5,7 +5,7 @@ from threading import Thread from typing import Union from flask import Flask, current_app -from sqlalchemy import exists, select +from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config @@ -30,6 +30,7 @@ from core.app.entities.task_entities import ( StreamEvent, WorkflowTaskState, ) +from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from core.tools.signature import sign_tool_file from extensions.ext_database import db @@ -57,13 +58,15 @@ class MessageCycleManager: self._message_has_file: set[str] = set() def get_message_event_type(self, message_id: str) -> StreamEvent: + # Fast path: cached determination from prior QueueMessageFileEvent if message_id in self._message_has_file: return StreamEvent.MESSAGE_FILE - with Session(db.engine, expire_on_commit=False) as session: - has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar() + # Use SQLAlchemy 2.x style session.scalar(select(...)) + with session_factory.create_session() as session: + message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id)) - if has_file: + if message_file: self._message_has_file.add(message_id) return StreamEvent.MESSAGE_FILE @@ -199,6 +202,8 @@ class MessageCycleManager: message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id)) if message_file and message_file.url is not None: + self._message_has_file.add(message_file.message_id) + # get tool file id tool_file_id = message_file.url.split("/")[-1] # trim extension diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index 98ea15e3fc..ce23da1e09 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,4 +1,4 @@ -from collections.abc import Generator, Mapping +from collections.abc import Generator from typing import Any from core.datasource.__base.datasource_plugin import DatasourcePlugin @@ -34,7 +34,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): def get_online_document_pages( self, user_id: str, - datasource_parameters: Mapping[str, Any], + datasource_parameters: dict[str, Any], provider_type: str, ) -> Generator[OnlineDocumentPagesMessage, None, None]: manager = PluginDatasourceManager() diff --git a/api/pyproject.toml b/api/pyproject.toml index 9f9bd11fa6..575c1434c5 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -64,7 +64,7 @@ dependencies = [ "pandas[excel,output-formatting,performance]~=2.2.2", "psycogreen~=1.0.2", "psycopg2-binary~=2.9.6", - "pycryptodome==3.19.1", + "pycryptodome==3.23.0", "pydantic~=2.11.4", "pydantic-extra-types~=2.10.3", "pydantic-settings~=2.11.0", diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 26ce8cad33..946b8cdfdb 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -131,7 +131,7 @@ class BillingService: headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = httpx.request(method, url, json=json, params=params, headers=headers) + response = httpx.request(method, url, json=json, params=params, headers=headers, follow_redirects=True) if method == "GET" and response.status_code != httpx.codes.OK: raise ValueError("Unable to retrieve billing information. Please try again later or contact support.") if method == "PUT": @@ -143,6 +143,9 @@ class BillingService: raise ValueError("Invalid arguments.") if method == "POST" and response.status_code != httpx.codes.OK: raise ValueError(f"Unable to send request to {url}. Please try again later or contact support.") + if method == "DELETE" and response.status_code != httpx.codes.OK: + logger.error("billing_service: DELETE response: %s %s", response.status_code, response.text) + raise ValueError(f"Unable to process delete request {url}. Please try again later or contact support.") return response.json() @staticmethod @@ -165,7 +168,7 @@ class BillingService: def delete_account(cls, account_id: str): """Delete account.""" params = {"account_id": account_id} - return cls._send_request("DELETE", "/account/", params=params) + return cls._send_request("DELETE", "/account", params=params) @classmethod def is_email_in_freeze(cls, email: str) -> bool: diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index e6492c230d..b5e6508006 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3 CACHE_REDIS_KEY_PREFIX = "plugin_autoupgrade_check_task:cached_plugin_manifests:" -CACHE_REDIS_TTL = 60 * 15 # 15 minutes +CACHE_REDIS_TTL = 60 * 60 # 1 hour def _get_redis_cache_key(plugin_id: str) -> str: diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py new file mode 100644 index 0000000000..421a5246eb --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -0,0 +1,454 @@ +"""Test multimodal image output handling in BaseAppRunner.""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueMessageFileEvent +from core.file.enums import FileTransferMethod, FileType +from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from models.enums import CreatorUserRole + + +class TestBaseAppRunnerMultimodal: + """Test that BaseAppRunner correctly handles multimodal image content.""" + + @pytest.fixture + def mock_user_id(self): + """Mock user ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_tenant_id(self): + """Mock tenant ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_message_id(self): + """Mock message ID.""" + return str(uuid4()) + + @pytest.fixture + def mock_queue_manager(self): + """Create a mock queue manager.""" + manager = MagicMock() + manager.invoke_from = InvokeFrom.SERVICE_API + return manager + + @pytest.fixture + def mock_tool_file(self): + """Create a mock tool file.""" + tool_file = MagicMock() + tool_file.id = str(uuid4()) + return tool_file + + @pytest.fixture + def mock_message_file(self): + """Create a mock message file.""" + message_file = MagicMock() + message_file.id = str(uuid4()) + return message_file + + def test_handle_multimodal_image_content_with_url( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from URL.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert + # Verify tool file was created from URL + mock_mgr.create_file_by_url.assert_called_once_with( + user_id=mock_user_id, + tenant_id=mock_tenant_id, + file_url=image_url, + conversation_id=None, + ) + + # Verify message file was created with correct parameters + mock_msg_file_class.assert_called_once() + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["message_id"] == mock_message_id + assert call_kwargs["type"] == FileType.IMAGE + assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE + assert call_kwargs["belongs_to"] == "assistant" + assert call_kwargs["created_by"] == mock_user_id + + # Verify database operations + mock_session.add.assert_called_once_with(mock_message_file) + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once_with(mock_message_file) + + # Verify event was published + mock_queue_manager.publish.assert_called_once() + publish_call = mock_queue_manager.publish.call_args + assert isinstance(publish_call[0][0], QueueMessageFileEvent) + assert publish_call[0][0].message_file_id == mock_message_file.id + # publish_from might be passed as positional or keyword argument + assert ( + publish_call[0][1] == PublishFrom.APPLICATION_MANAGER + or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER + ) + + def test_handle_multimodal_image_content_with_base64( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from base64 data.""" + # Arrange + import base64 + + # Create a small test image (1x1 PNG) + test_image_data = base64.b64encode( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde" + ).decode() + content = ImagePromptMessageContent( + base64_data=test_image_data, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_raw.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert + # Verify tool file was created from base64 + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + assert call_kwargs["user_id"] == mock_user_id + assert call_kwargs["tenant_id"] == mock_tenant_id + assert call_kwargs["conversation_id"] is None + assert "file_binary" in call_kwargs + assert call_kwargs["mimetype"] == "image/png" + assert call_kwargs["filename"].startswith("generated_image") + assert call_kwargs["filename"].endswith(".png") + + # Verify message file was created + mock_msg_file_class.assert_called_once() + + # Verify database operations + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + # Verify event was published + mock_queue_manager.publish.assert_called_once() + + def test_handle_multimodal_image_content_with_base64_data_uri( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test handling image from base64 data with URI prefix.""" + # Arrange + # Data URI format: data:image/png;base64, + test_image_data = ( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + ) + content = ImagePromptMessageContent( + base64_data=f"data:image/png;base64,{test_image_data}", + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_raw.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify that base64 data was extracted correctly (without prefix) + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + # The base64 data should be decoded, so we check the binary was passed + assert "file_binary" in call_kwargs + + def test_handle_multimodal_image_content_without_url_or_base64( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + ): + """Test handling image content without URL or base64 data.""" + # Arrange + content = ImagePromptMessageContent( + url="", + base64_data="", + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - should not create any files or publish events + mock_mgr_class.assert_not_called() + mock_msg_file_class.assert_not_called() + mock_session.add.assert_not_called() + mock_queue_manager.publish.assert_not_called() + + def test_handle_multimodal_image_content_with_error( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + ): + """Test handling image content when an error occurs.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock to raise exception + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.side_effect = Exception("Network error") + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + # Should not raise exception, just log it + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - should not create message file or publish event on error + mock_msg_file_class.assert_not_called() + mock_session.add.assert_not_called() + mock_queue_manager.publish.assert_not_called() + + def test_handle_multimodal_image_content_debugger_mode( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test that debugger mode sets correct created_by_role.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify created_by_role is ACCOUNT for debugger mode + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT + + def test_handle_multimodal_image_content_service_api_mode( + self, + mock_user_id, + mock_tenant_id, + mock_message_id, + mock_queue_manager, + mock_tool_file, + mock_message_file, + ): + """Test that service API mode sets correct created_by_role.""" + # Arrange + image_url = "http://example.com/image.png" + content = ImagePromptMessageContent( + url=image_url, + format="png", + mime_type="image/png", + ) + mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API + + with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + # Setup mock tool file manager + mock_mgr = MagicMock() + mock_mgr.create_file_by_url.return_value = mock_tool_file + mock_mgr_class.return_value = mock_mgr + + with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + # Setup mock message file + mock_msg_file_class.return_value = mock_message_file + + with patch("core.app.apps.base_app_runner.db.session") as mock_session: + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.refresh = MagicMock() + + # Act + # Create a mock runner with the method bound + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert - verify created_by_role is END_USER for service API + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index 5ef7f0d7f4..5a43a247e3 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -1,7 +1,6 @@ """Unit tests for the message cycle manager optimization.""" -from types import SimpleNamespace -from unittest.mock import ANY, Mock, patch +from unittest.mock import Mock, patch import pytest from flask import current_app @@ -28,17 +27,14 @@ class TestMessageCycleManagerOptimization: def test_get_message_event_type_with_message_file(self, message_cycle_manager): """Test get_message_event_type returns MESSAGE_FILE when message has files.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = mock_message_file + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = mock_message_file # Execute with current_app.app_context(): @@ -46,19 +42,16 @@ class TestMessageCycleManagerOptimization: # Assert assert result == StreamEvent.MESSAGE_FILE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_get_message_event_type_without_message_file(self, message_cycle_manager): """Test get_message_event_type returns MESSAGE when message has no files.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and no message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = None + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = None # Execute with current_app.app_context(): @@ -66,21 +59,18 @@ class TestMessageCycleManagerOptimization: # Assert assert result == StreamEvent.MESSAGE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and message file mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = mock_message_file + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = mock_message_file # Execute: compute event type once, then pass to message_to_stream_response with current_app.app_context(): @@ -94,11 +84,11 @@ class TestMessageCycleManagerOptimization: assert result.answer == "Hello world" assert result.id == "test-message-id" assert result.event == StreamEvent.MESSAGE_FILE - mock_session.query.return_value.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager): """Test that message_to_stream_response skips database query when event_type is provided.""" - with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Execute with event_type provided result = message_cycle_manager.message_to_stream_response( answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE @@ -109,8 +99,8 @@ class TestMessageCycleManagerOptimization: assert result.answer == "Hello world" assert result.id == "test-message-id" assert result.event == StreamEvent.MESSAGE - # Should not query database when event_type is provided - mock_session_class.assert_not_called() + # Should not open a session when event_type is provided + mock_session_factory.create_session.assert_not_called() def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager): """Test message_to_stream_response with from_variable_selector parameter.""" @@ -130,24 +120,21 @@ class TestMessageCycleManagerOptimization: def test_optimization_usage_example(self, message_cycle_manager): """Test the optimization pattern that should be used by callers.""" # Step 1: Get event type once (this queries database) - with ( - patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class, - patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())), - ): + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: mock_session = Mock() - mock_session_class.return_value.__enter__.return_value = mock_session - # Current implementation uses session.query(...).scalar() - mock_session.query.return_value.scalar.return_value = None # No files + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + # Current implementation uses session.scalar(select(...)) + mock_session.scalar.return_value = None # No files with current_app.app_context(): event_type = message_cycle_manager.get_message_event_type("test-message-id") - # Should query database once - mock_session_class.assert_called_once_with(ANY, expire_on_commit=False) + # Should open session once + mock_session_factory.create_session.assert_called_once() assert event_type == StreamEvent.MESSAGE # Step 2: Use event_type for multiple calls (no additional queries) - with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class: - mock_session_class.return_value.__enter__.return_value = Mock() + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: + mock_session_factory.create_session.return_value.__enter__.return_value = Mock() chunk1_response = message_cycle_manager.message_to_stream_response( answer="Chunk 1", message_id="test-message-id", event_type=event_type @@ -157,8 +144,8 @@ class TestMessageCycleManagerOptimization: answer="Chunk 2", message_id="test-message-id", event_type=event_type ) - # Should not query database again - mock_session_class.assert_not_called() + # Should not open session again when event_type provided + mock_session_factory.create_session.assert_not_called() assert chunk1_response.event == StreamEvent.MESSAGE assert chunk2_response.event == StreamEvent.MESSAGE diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index d00743278e..eecb3c7672 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -171,22 +171,26 @@ class TestBillingServiceSendRequest: "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] ) def test_delete_request_non_200_with_valid_json(self, mock_httpx_request, mock_billing_config, status_code): - """Test DELETE request with non-200 status code but valid JSON response. + """Test DELETE request with non-200 status code raises ValueError. - DELETE doesn't check status code, so it returns the error JSON. + DELETE now checks status code and raises ValueError for non-200 responses. """ # Arrange error_response = {"detail": "Error message"} mock_response = MagicMock() mock_response.status_code = status_code + mock_response.text = "Error message" mock_response.json.return_value = error_response mock_httpx_request.return_value = mock_response - # Act - result = BillingService._send_request("DELETE", "/test", json={"key": "value"}) - - # Assert - assert result == error_response + # Act & Assert + with patch("services.billing_service.logger") as mock_logger: + with pytest.raises(ValueError) as exc_info: + BillingService._send_request("DELETE", "/test", json={"key": "value"}) + assert "Unable to process delete request" in str(exc_info.value) + # Verify error logging + mock_logger.error.assert_called_once() + assert "DELETE response" in str(mock_logger.error.call_args) @pytest.mark.parametrize( "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] @@ -210,9 +214,9 @@ class TestBillingServiceSendRequest: "status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND] ) def test_delete_request_non_200_with_invalid_json(self, mock_httpx_request, mock_billing_config, status_code): - """Test DELETE request with non-200 status code and invalid JSON response raises exception. + """Test DELETE request with non-200 status code raises ValueError before JSON parsing. - DELETE doesn't check status code, so it calls response.json() which raises JSONDecodeError + DELETE now checks status code before calling response.json(), so ValueError is raised when the response cannot be parsed as JSON (e.g., empty response). """ # Arrange @@ -223,8 +227,13 @@ class TestBillingServiceSendRequest: mock_httpx_request.return_value = mock_response # Act & Assert - with pytest.raises(json.JSONDecodeError): - BillingService._send_request("DELETE", "/test", json={"key": "value"}) + with patch("services.billing_service.logger") as mock_logger: + with pytest.raises(ValueError) as exc_info: + BillingService._send_request("DELETE", "/test", json={"key": "value"}) + assert "Unable to process delete request" in str(exc_info.value) + # Verify error logging + mock_logger.error.assert_called_once() + assert "DELETE response" in str(mock_logger.error.call_args) def test_retry_on_request_error(self, mock_httpx_request, mock_billing_config): """Test that _send_request retries on httpx.RequestError.""" @@ -789,7 +798,7 @@ class TestBillingServiceAccountManagement: # Assert assert result == expected_response - mock_send_request.assert_called_once_with("DELETE", "/account/", params={"account_id": account_id}) + mock_send_request.assert_called_once_with("DELETE", "/account", params={"account_id": account_id}) def test_is_email_in_freeze_true(self, mock_send_request): """Test checking if email is frozen (returns True).""" diff --git a/api/uv.lock b/api/uv.lock index 7853d06bf6..7808c16a8c 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1633,7 +1633,7 @@ requires-dist = [ { name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=2.2.2" }, { name = "psycogreen", specifier = "~=1.0.2" }, { name = "psycopg2-binary", specifier = "~=2.9.6" }, - { name = "pycryptodome", specifier = "==3.19.1" }, + { name = "pycryptodome", specifier = "==3.23.0" }, { name = "pydantic", specifier = "~=2.11.4" }, { name = "pydantic-extra-types", specifier = "~=2.10.3" }, { name = "pydantic-settings", specifier = "~=2.11.0" }, @@ -4796,20 +4796,21 @@ wheels = [ [[package]] name = "pycryptodome" -version = "3.19.1" +version = "3.23.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b1/38/42a8855ff1bf568c61ca6557e2203f318fb7afeadaf2eb8ecfdbde107151/pycryptodome-3.19.1.tar.gz", hash = "sha256:8ae0dd1bcfada451c35f9e29a3e5db385caabc190f98e4a80ad02a61098fb776", size = 4782144, upload-time = "2023-12-28T06:52:40.741Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/a6/8452177684d5e906854776276ddd34eca30d1b1e15aa1ee9cefc289a33f5/pycryptodome-3.23.0.tar.gz", hash = "sha256:447700a657182d60338bab09fdb27518f8856aecd80ae4c6bdddb67ff5da44ef", size = 4921276, upload-time = "2025-05-17T17:21:45.242Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/ef/4931bc30674f0de0ca0e827b58c8b0c17313a8eae2754976c610b866118b/pycryptodome-3.19.1-cp35-abi3-macosx_10_9_universal2.whl", hash = "sha256:67939a3adbe637281c611596e44500ff309d547e932c449337649921b17b6297", size = 2417027, upload-time = "2023-12-28T06:51:50.138Z" }, - { url = "https://files.pythonhosted.org/packages/67/e6/238c53267fd8d223029c0a0d3730cb1b6594d60f62e40c4184703dc490b1/pycryptodome-3.19.1-cp35-abi3-macosx_10_9_x86_64.whl", hash = "sha256:11ddf6c9b52116b62223b6a9f4741bc4f62bb265392a4463282f7f34bb287180", size = 1579728, upload-time = "2023-12-28T06:51:52.385Z" }, - { url = "https://files.pythonhosted.org/packages/7c/87/7181c42c8d5ba89822a4b824830506d0aeec02959bb893614767e3279846/pycryptodome-3.19.1-cp35-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3e6f89480616781d2a7f981472d0cdb09b9da9e8196f43c1234eff45c915766", size = 2051440, upload-time = "2023-12-28T06:51:55.751Z" }, - { url = "https://files.pythonhosted.org/packages/34/dd/332c4c0055527d17dac317ed9f9c864fc047b627d82f4b9a56c110afc6fc/pycryptodome-3.19.1-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27e1efcb68993b7ce5d1d047a46a601d41281bba9f1971e6be4aa27c69ab8065", size = 2125379, upload-time = "2023-12-28T06:51:58.567Z" }, - { url = "https://files.pythonhosted.org/packages/24/9e/320b885ea336c218ff54ec2b276cd70ba6904e4f5a14a771ed39a2c47d59/pycryptodome-3.19.1-cp35-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c6273ca5a03b672e504995529b8bae56da0ebb691d8ef141c4aa68f60765700", size = 2153951, upload-time = "2023-12-28T06:52:01.699Z" }, - { url = "https://files.pythonhosted.org/packages/f4/54/8ae0c43d1257b41bc9d3277c3f875174fd8ad86b9567f0b8609b99c938ee/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:b0bfe61506795877ff974f994397f0c862d037f6f1c0bfc3572195fc00833b96", size = 2044041, upload-time = "2023-12-28T06:52:03.737Z" }, - { url = "https://files.pythonhosted.org/packages/45/93/f8450a92cc38541c3ba1f4cb4e267e15ae6d6678ca617476d52c3a3764d4/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_i686.whl", hash = "sha256:f34976c5c8eb79e14c7d970fb097482835be8d410a4220f86260695ede4c3e17", size = 2182446, upload-time = "2023-12-28T06:52:05.588Z" }, - { url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914, upload-time = "2023-12-28T06:52:07.44Z" }, - { url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105, upload-time = "2023-12-28T06:52:09.585Z" }, - { url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222, upload-time = "2023-12-28T06:52:11.534Z" }, + { url = "https://files.pythonhosted.org/packages/db/6c/a1f71542c969912bb0e106f64f60a56cc1f0fabecf9396f45accbe63fa68/pycryptodome-3.23.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:187058ab80b3281b1de11c2e6842a357a1f71b42cb1e15bce373f3d238135c27", size = 2495627, upload-time = "2025-05-17T17:20:47.139Z" }, + { url = "https://files.pythonhosted.org/packages/6e/4e/a066527e079fc5002390c8acdd3aca431e6ea0a50ffd7201551175b47323/pycryptodome-3.23.0-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:cfb5cd445280c5b0a4e6187a7ce8de5a07b5f3f897f235caa11f1f435f182843", size = 1640362, upload-time = "2025-05-17T17:20:50.392Z" }, + { url = "https://files.pythonhosted.org/packages/50/52/adaf4c8c100a8c49d2bd058e5b551f73dfd8cb89eb4911e25a0c469b6b4e/pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67bd81fcbe34f43ad9422ee8fd4843c8e7198dd88dd3d40e6de42ee65fbe1490", size = 2182625, upload-time = "2025-05-17T17:20:52.866Z" }, + { url = "https://files.pythonhosted.org/packages/5f/e9/a09476d436d0ff1402ac3867d933c61805ec2326c6ea557aeeac3825604e/pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8987bd3307a39bc03df5c8e0e3d8be0c4c3518b7f044b0f4c15d1aa78f52575", size = 2268954, upload-time = "2025-05-17T17:20:55.027Z" }, + { url = "https://files.pythonhosted.org/packages/f9/c5/ffe6474e0c551d54cab931918127c46d70cab8f114e0c2b5a3c071c2f484/pycryptodome-3.23.0-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa0698f65e5b570426fc31b8162ed4603b0c2841cbb9088e2b01641e3065915b", size = 2308534, upload-time = "2025-05-17T17:20:57.279Z" }, + { url = "https://files.pythonhosted.org/packages/18/28/e199677fc15ecf43010f2463fde4c1a53015d1fe95fb03bca2890836603a/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:53ecbafc2b55353edcebd64bf5da94a2a2cdf5090a6915bcca6eca6cc452585a", size = 2181853, upload-time = "2025-05-17T17:20:59.322Z" }, + { url = "https://files.pythonhosted.org/packages/ce/ea/4fdb09f2165ce1365c9eaefef36625583371ee514db58dc9b65d3a255c4c/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:156df9667ad9f2ad26255926524e1c136d6664b741547deb0a86a9acf5ea631f", size = 2342465, upload-time = "2025-05-17T17:21:03.83Z" }, + { url = "https://files.pythonhosted.org/packages/22/82/6edc3fc42fe9284aead511394bac167693fb2b0e0395b28b8bedaa07ef04/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:dea827b4d55ee390dc89b2afe5927d4308a8b538ae91d9c6f7a5090f397af1aa", size = 2267414, upload-time = "2025-05-17T17:21:06.72Z" }, + { url = "https://files.pythonhosted.org/packages/59/fe/aae679b64363eb78326c7fdc9d06ec3de18bac68be4b612fc1fe8902693c/pycryptodome-3.23.0-cp37-abi3-win32.whl", hash = "sha256:507dbead45474b62b2bbe318eb1c4c8ee641077532067fec9c1aa82c31f84886", size = 1768484, upload-time = "2025-05-17T17:21:08.535Z" }, + { url = "https://files.pythonhosted.org/packages/54/2f/e97a1b8294db0daaa87012c24a7bb714147c7ade7656973fd6c736b484ff/pycryptodome-3.23.0-cp37-abi3-win_amd64.whl", hash = "sha256:c75b52aacc6c0c260f204cbdd834f76edc9fb0d8e0da9fbf8352ef58202564e2", size = 1799636, upload-time = "2025-05-17T17:21:10.393Z" }, + { url = "https://files.pythonhosted.org/packages/18/3d/f9441a0d798bf2b1e645adc3265e55706aead1255ccdad3856dbdcffec14/pycryptodome-3.23.0-cp37-abi3-win_arm64.whl", hash = "sha256:11eeeb6917903876f134b56ba11abe95c0b0fd5e3330def218083c7d98bbcb3c", size = 1703675, upload-time = "2025-05-17T17:21:13.146Z" }, ] [[package]] @@ -5003,11 +5004,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.6.0" +version = "6.6.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d8/f4/801632a8b62a805378b6af2b5a3fcbfd8923abf647e0ed1af846a83433b2/pypdf-6.6.0.tar.gz", hash = "sha256:4c887ef2ea38d86faded61141995a3c7d068c9d6ae8477be7ae5de8a8e16592f", size = 5281063, upload-time = "2026-01-09T11:20:11.786Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b8/bb/a44bab1ac3c54dbcf653d7b8bcdee93dddb2d3bf025a3912cacb8149a2f2/pypdf-6.6.2.tar.gz", hash = "sha256:0a3ea3b3303982333404e22d8f75d7b3144f9cf4b2970b96856391a516f9f016", size = 5281850, upload-time = "2026-01-26T11:57:55.964Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/ba/96f99276194f720e74ed99905a080f6e77810558874e8935e580331b46de/pypdf-6.6.0-py3-none-any.whl", hash = "sha256:bca9091ef6de36c7b1a81e09327c554b7ce51e88dad68f5890c2b4a4417f1fd7", size = 328963, upload-time = "2026-01-09T11:20:09.278Z" }, + { url = "https://files.pythonhosted.org/packages/7d/be/549aaf1dfa4ab4aed29b09703d2fb02c4366fc1f05e880948c296c5764b9/pypdf-6.6.2-py3-none-any.whl", hash = "sha256:44c0c9811cfb3b83b28f1c3d054531d5b8b81abaedee0d8cb403650d023832ba", size = 329132, upload-time = "2026-01-26T11:57:54.099Z" }, ] [[package]] diff --git a/web/app/components/base/chat/chat/hooks.multimodal.spec.ts b/web/app/components/base/chat/chat/hooks.multimodal.spec.ts new file mode 100644 index 0000000000..2975d62887 --- /dev/null +++ b/web/app/components/base/chat/chat/hooks.multimodal.spec.ts @@ -0,0 +1,178 @@ +/** + * Tests for multimodal image file handling in chat hooks. + * Tests the file object conversion logic without full hook integration. + */ + +describe('Multimodal File Handling', () => { + describe('File type to MIME type mapping', () => { + it('should map image to image/png', () => { + const fileType: string = 'image' + const expectedMime = 'image/png' + const mimeType = fileType === 'image' ? 'image/png' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + + it('should map video to video/mp4', () => { + const fileType: string = 'video' + const expectedMime = 'video/mp4' + const mimeType = fileType === 'video' ? 'video/mp4' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + + it('should map audio to audio/mpeg', () => { + const fileType: string = 'audio' + const expectedMime = 'audio/mpeg' + const mimeType = fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + + it('should map unknown to application/octet-stream', () => { + const fileType: string = 'unknown' + const expectedMime = 'application/octet-stream' + const mimeType = ['image', 'video', 'audio'].includes(fileType) ? 'image/png' : 'application/octet-stream' + expect(mimeType).toBe(expectedMime) + }) + }) + + describe('TransferMethod selection', () => { + it('should select remote_url for images', () => { + const fileType: string = 'image' + const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file' + expect(transferMethod).toBe('remote_url') + }) + + it('should select local_file for non-images', () => { + const fileType: string = 'video' + const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file' + expect(transferMethod).toBe('local_file') + }) + }) + + describe('File extension mapping', () => { + it('should use .png extension for images', () => { + const fileType: string = 'image' + const expectedExtension = '.png' + const extension = fileType === 'image' ? 'png' : 'bin' + expect(extension).toBe(expectedExtension.replace('.', '')) + }) + + it('should use .mp4 extension for videos', () => { + const fileType: string = 'video' + const expectedExtension = '.mp4' + const extension = fileType === 'video' ? 'mp4' : 'bin' + expect(extension).toBe(expectedExtension.replace('.', '')) + }) + + it('should use .mp3 extension for audio', () => { + const fileType: string = 'audio' + const expectedExtension = '.mp3' + const extension = fileType === 'audio' ? 'mp3' : 'bin' + expect(extension).toBe(expectedExtension.replace('.', '')) + }) + }) + + describe('File name generation', () => { + it('should generate correct file name for images', () => { + const fileType: string = 'image' + const expectedName = 'generated_image.png' + const fileName = `generated_${fileType}.${fileType === 'image' ? 'png' : 'bin'}` + expect(fileName).toBe(expectedName) + }) + + it('should generate correct file name for videos', () => { + const fileType: string = 'video' + const expectedName = 'generated_video.mp4' + const fileName = `generated_${fileType}.${fileType === 'video' ? 'mp4' : 'bin'}` + expect(fileName).toBe(expectedName) + }) + + it('should generate correct file name for audio', () => { + const fileType: string = 'audio' + const expectedName = 'generated_audio.mp3' + const fileName = `generated_${fileType}.${fileType === 'audio' ? 'mp3' : 'bin'}` + expect(fileName).toBe(expectedName) + }) + }) + + describe('SupportFileType mapping', () => { + it('should map image type to image supportFileType', () => { + const fileType: string = 'image' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('image') + }) + + it('should map video type to video supportFileType', () => { + const fileType: string = 'video' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('video') + }) + + it('should map audio type to audio supportFileType', () => { + const fileType: string = 'audio' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('audio') + }) + + it('should map unknown type to document supportFileType', () => { + const fileType: string = 'unknown' + const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document' + expect(supportFileType).toBe('document') + }) + }) + + describe('File conversion logic', () => { + it('should detect existing transferMethod', () => { + const fileWithTransferMethod = { + id: 'file-123', + transferMethod: 'remote_url' as const, + type: 'image/png', + name: 'test.png', + size: 1024, + supportFileType: 'image', + progress: 100, + } + const hasTransferMethod = 'transferMethod' in fileWithTransferMethod + expect(hasTransferMethod).toBe(true) + }) + + it('should detect missing transferMethod', () => { + const fileWithoutTransferMethod = { + id: 'file-456', + type: 'image', + url: 'http://example.com/image.png', + belongs_to: 'assistant', + } + const hasTransferMethod = 'transferMethod' in fileWithoutTransferMethod + expect(hasTransferMethod).toBe(false) + }) + + it('should create file with size 0 for generated files', () => { + const expectedSize = 0 + expect(expectedSize).toBe(0) + }) + }) + + describe('Agent vs Non-Agent mode logic', () => { + it('should check for agent_thoughts to determine mode', () => { + const agentResponse: { agent_thoughts?: Array> } = { + agent_thoughts: [{}], + } + const isAgentMode = agentResponse.agent_thoughts && agentResponse.agent_thoughts.length > 0 + expect(isAgentMode).toBe(true) + }) + + it('should detect non-agent mode when agent_thoughts is empty', () => { + const nonAgentResponse: { agent_thoughts?: Array> } = { + agent_thoughts: [], + } + const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0 + expect(isAgentMode).toBe(false) + }) + + it('should detect non-agent mode when agent_thoughts is undefined', () => { + const nonAgentResponse: { agent_thoughts?: Array> } = {} + const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0 + expect(isAgentMode).toBeFalsy() + }) + }) +}) diff --git a/web/app/components/base/chat/chat/hooks.ts b/web/app/components/base/chat/chat/hooks.ts index 9b8a9b11dc..182aeebdbb 100644 --- a/web/app/components/base/chat/chat/hooks.ts +++ b/web/app/components/base/chat/chat/hooks.ts @@ -419,9 +419,40 @@ export const useChat = ( } }, onFile(file) { + // Convert simple file type to MIME type for non-agent mode + // Backend sends: { id, type: "image", belongs_to, url } + // Frontend expects: { id, type: "image/png", transferMethod, url, uploadedId, supportFileType, name, size } + + // Determine file type for MIME conversion + const fileType = (file as { type?: string }).type || 'image' + + // If file already has transferMethod, use it as base and ensure all required fields exist + // Otherwise, create a new complete file object + const baseFile = ('transferMethod' in file) ? (file as Partial) : null + + const convertedFile: FileEntity = { + id: baseFile?.id || (file as { id: string }).id, + type: baseFile?.type || (fileType === 'image' ? 'image/png' : fileType === 'video' ? 'video/mp4' : fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'), + transferMethod: (baseFile?.transferMethod as FileEntity['transferMethod']) || (fileType === 'image' ? 'remote_url' : 'local_file'), + uploadedId: baseFile?.uploadedId || (file as { id: string }).id, + supportFileType: baseFile?.supportFileType || (fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'), + progress: baseFile?.progress ?? 100, + name: baseFile?.name || `generated_${fileType}.${fileType === 'image' ? 'png' : fileType === 'video' ? 'mp4' : fileType === 'audio' ? 'mp3' : 'bin'}`, + url: baseFile?.url || (file as { url?: string }).url, + size: baseFile?.size ?? 0, // Generated files don't have a known size + } + + // For agent mode, add files to the last thought const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1] - if (lastThought) - responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(lastThought as any).message_files, file] + if (lastThought) { + const thought = lastThought as { message_files?: FileEntity[] } + responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(thought.message_files ?? []), convertedFile] + } + // For non-agent mode, add files directly to responseItem.message_files + else { + const currentFiles = (responseItem.message_files as FileEntity[] | undefined) ?? [] + responseItem.message_files = [...currentFiles, convertedFile] + } updateCurrentQAOnTree({ placeholderQuestionId, diff --git a/web/app/components/base/icons/icon-gallery.stories.tsx b/web/app/components/base/icons/icon-gallery.stories.tsx index 15206f2735..8d49f70ce2 100644 --- a/web/app/components/base/icons/icon-gallery.stories.tsx +++ b/web/app/components/base/icons/icon-gallery.stories.tsx @@ -1,9 +1,9 @@ +/// import type { Meta, StoryObj } from '@storybook/nextjs-vite' import * as React from 'react' -declare const require: any - type IconComponent = React.ComponentType> +type IconModule = { default: IconComponent } type IconEntry = { name: string @@ -12,18 +12,16 @@ type IconEntry = { Component: IconComponent } -const iconContext = require.context('./src', true, /\.tsx$/) +const iconModules: Record = import.meta.glob('./src/**/*.tsx', { eager: true }) -const iconEntries: IconEntry[] = iconContext - .keys() - .filter((key: string) => !key.endsWith('.stories.tsx') && !key.endsWith('.spec.tsx')) - .map((key: string) => { - const mod = iconContext(key) - const Component = mod.default as IconComponent | undefined +const iconEntries: IconEntry[] = Object.entries(iconModules) + .filter(([key]) => !key.endsWith('.stories.tsx') && !key.endsWith('.spec.tsx')) + .map(([key, mod]) => { + const Component = mod.default if (!Component) return null - const relativePath = key.replace(/^\.\//, '') + const relativePath = key.replace(/^\.\/src\//, '') const path = `app/components/base/icons/src/${relativePath}` const parts = relativePath.split('/') const fileName = parts.pop() || '' diff --git a/web/app/components/datasets/hit-testing/index.spec.tsx b/web/app/components/datasets/hit-testing/index.spec.tsx index 6bab3afb6a..07a78cd55f 100644 --- a/web/app/components/datasets/hit-testing/index.spec.tsx +++ b/web/app/components/datasets/hit-testing/index.spec.tsx @@ -2039,8 +2039,13 @@ describe('Integration: Hit Testing Flow', () => { renderWithProviders() + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Type query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'Test query' } }) // Find submit button by class @@ -2054,8 +2059,13 @@ describe('Integration: Hit Testing Flow', () => { const { container } = renderWithProviders() + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Type query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'Test query' } }) // Component should still be functional - check for the main container @@ -2089,10 +2099,15 @@ describe('Integration: Hit Testing Flow', () => { isLoading: false, } as unknown as ReturnType) - const { container } = renderWithProviders() + const { container: _container } = renderWithProviders() + + // Wait for textbox to be rendered with timeout for CI environment + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) // Type query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'Test query' } }) // Submit @@ -2101,8 +2116,13 @@ describe('Integration: Hit Testing Flow', () => { if (submitButton) fireEvent.click(submitButton) - // Verify the component is still rendered after submission - expect(container.firstChild).toBeInTheDocument() + // Wait for the mutation to complete + await waitFor( + () => { + expect(mockHitTestingMutateAsync).toHaveBeenCalled() + }, + { timeout: 3000 }, + ) }) it('should render ResultItem components for non-external results', async () => { @@ -2127,10 +2147,15 @@ describe('Integration: Hit Testing Flow', () => { isLoading: false, } as unknown as ReturnType) - const { container } = renderWithProviders() + const { container: _container } = renderWithProviders() + + // Wait for component to be fully rendered with longer timeout + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) // Submit a query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'Test query' } }) const buttons = screen.getAllByRole('button') @@ -2138,8 +2163,13 @@ describe('Integration: Hit Testing Flow', () => { if (submitButton) fireEvent.click(submitButton) - // Verify component is rendered after submission - expect(container.firstChild).toBeInTheDocument() + // Wait for mutation to complete with longer timeout + await waitFor( + () => { + expect(mockHitTestingMutateAsync).toHaveBeenCalled() + }, + { timeout: 3000 }, + ) }) it('should render external results when dataset is external', async () => { @@ -2165,8 +2195,14 @@ describe('Integration: Hit Testing Flow', () => { // Component should render expect(container.firstChild).toBeInTheDocument() + + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Type in textarea to verify component is functional - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'Test query' } }) const buttons = screen.getAllByRole('button') @@ -2174,9 +2210,13 @@ describe('Integration: Hit Testing Flow', () => { if (submitButton) fireEvent.click(submitButton) - await waitFor(() => { - expect(screen.getByRole('textbox')).toBeInTheDocument() - }) + // Verify component is still functional after submission + await waitFor( + () => { + expect(screen.getByRole('textbox')).toBeInTheDocument() + }, + { timeout: 3000 }, + ) }) }) @@ -2260,8 +2300,13 @@ describe('renderHitResults Coverage', () => { const { container } = renderWithProviders() + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Enter query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'test query' } }) // Submit @@ -2386,8 +2431,13 @@ describe('HitTestingPage Internal Functions Coverage', () => { const { container } = renderWithProviders() + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Enter query and submit - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'test query' } }) const buttons = screen.getAllByRole('button') @@ -2400,7 +2450,7 @@ describe('HitTestingPage Internal Functions Coverage', () => { // Wait for state updates await waitFor(() => { expect(container.firstChild).toBeInTheDocument() - }, { timeout: 2000 }) + }, { timeout: 3000 }) // Verify mutation was called expect(mockHitTestingMutateAsync).toHaveBeenCalled() @@ -2445,8 +2495,13 @@ describe('HitTestingPage Internal Functions Coverage', () => { const { container } = renderWithProviders() + // Wait for textbox with timeout for CI + const textarea = await waitFor( + () => screen.getByRole('textbox'), + { timeout: 3000 }, + ) + // Submit a query - const textarea = screen.getByRole('textbox') fireEvent.change(textarea, { target: { value: 'test' } }) const buttons = screen.getAllByRole('button') @@ -2458,7 +2513,7 @@ describe('HitTestingPage Internal Functions Coverage', () => { // Verify the component renders await waitFor(() => { expect(container.firstChild).toBeInTheDocument() - }) + }, { timeout: 3000 }) }) }) diff --git a/web/app/components/plugins/marketplace/index.spec.tsx b/web/app/components/plugins/marketplace/index.spec.tsx index 654b667deb..1c0c700177 100644 --- a/web/app/components/plugins/marketplace/index.spec.tsx +++ b/web/app/components/plugins/marketplace/index.spec.tsx @@ -162,6 +162,44 @@ vi.mock('@/utils/var', () => ({ getMarketplaceUrl: (path: string, _params?: Record) => `https://marketplace.dify.ai${path}`, })) +// Mock marketplace client used by marketplace utils +vi.mock('@/service/client', () => ({ + marketplaceClient: { + collections: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({ + data: { + collections: [ + { + name: 'collection-1', + label: { 'en-US': 'Collection 1' }, + description: { 'en-US': 'Desc' }, + rule: '', + created_at: '2024-01-01', + updated_at: '2024-01-01', + searchable: true, + search_params: { query: '', sort_by: 'install_count', sort_order: 'DESC' }, + }, + ], + }, + })), + collectionPlugins: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({ + data: { + plugins: [ + { type: 'plugin', org: 'test', name: 'plugin1', tags: [] }, + ], + }, + })), + // Some utils paths may call searchAdvanced; provide a minimal stub + searchAdvanced: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({ + data: { + plugins: [ + { type: 'plugin', org: 'test', name: 'plugin1', tags: [] }, + ], + total: 1, + }, + })), + }, +})) + // Mock context/query-client vi.mock('@/context/query-client', () => ({ TanstackQueryInitializer: ({ children }: { children: React.ReactNode }) =>
{children}
, @@ -1474,7 +1512,24 @@ describe('flatMap Coverage', () => { // ================================ // Async Utils Tests // ================================ + +// Narrow mock surface and avoid any in tests +// Types are local to this spec to keep scope minimal + +type FnMock = ReturnType + +type MarketplaceClientMock = { + collectionPlugins: FnMock + collections: FnMock +} + describe('Async Utils', () => { + let marketplaceClientMock: MarketplaceClientMock + + beforeAll(async () => { + const mod = await import('@/service/client') + marketplaceClientMock = mod.marketplaceClient as unknown as MarketplaceClientMock + }) beforeEach(() => { vi.clearAllMocks() }) @@ -1490,12 +1545,10 @@ describe('Async Utils', () => { { type: 'plugin', org: 'test', name: 'plugin2' }, ] - globalThis.fetch = vi.fn().mockResolvedValue( - new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + // Adjusted to our mocked marketplaceClient instead of fetch + marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({ + data: { plugins: mockPlugins }, + }) const { getMarketplacePluginsByCollectionId } = await import('./utils') const result = await getMarketplacePluginsByCollectionId('test-collection', { @@ -1504,12 +1557,13 @@ describe('Async Utils', () => { type: 'plugin', }) - expect(globalThis.fetch).toHaveBeenCalled() + expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled() expect(result).toHaveLength(2) }) it('should handle fetch error and return empty array', async () => { - globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error')) + // Simulate error from client + marketplaceClientMock.collectionPlugins.mockRejectedValueOnce(new Error('Network error')) const { getMarketplacePluginsByCollectionId } = await import('./utils') const result = await getMarketplacePluginsByCollectionId('test-collection') @@ -1519,25 +1573,18 @@ describe('Async Utils', () => { it('should pass abort signal when provided', async () => { const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] - globalThis.fetch = vi.fn().mockResolvedValue( - new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + // Our client mock receives the signal as second arg + marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({ + data: { plugins: mockPlugins }, + }) const controller = new AbortController() const { getMarketplacePluginsByCollectionId } = await import('./utils') await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal }) - // oRPC uses Request objects, so check that fetch was called with a Request containing the right URL - expect(globalThis.fetch).toHaveBeenCalledWith( - expect.any(Request), - expect.any(Object), - ) - const call = vi.mocked(globalThis.fetch).mock.calls[0] - const request = call[0] as Request - expect(request.url).toContain('test-collection') + expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled() + const call = marketplaceClientMock.collectionPlugins.mock.calls[0] + expect(call[1]).toMatchObject({ signal: controller.signal }) }) }) @@ -1548,23 +1595,17 @@ describe('Async Utils', () => { ] const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }] - let callCount = 0 - globalThis.fetch = vi.fn().mockImplementation(() => { - callCount++ - if (callCount === 1) { - return Promise.resolve( - new Response(JSON.stringify({ data: { collections: mockCollections } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + // Simulate two-step client calls: collections then collectionPlugins + let stage = 0 + marketplaceClientMock.collections.mockImplementationOnce(async () => { + stage = 1 + return { data: { collections: mockCollections } } + }) + marketplaceClientMock.collectionPlugins.mockImplementation(async () => { + if (stage === 1) { + return { data: { plugins: mockPlugins } } } - return Promise.resolve( - new Response(JSON.stringify({ data: { plugins: mockPlugins } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) + return { data: { plugins: [] } } }) const { getMarketplaceCollectionsAndPlugins } = await import('./utils') @@ -1578,7 +1619,8 @@ describe('Async Utils', () => { }) it('should handle fetch error and return empty data', async () => { - globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error')) + // Simulate client error + marketplaceClientMock.collections.mockRejectedValueOnce(new Error('Network error')) const { getMarketplaceCollectionsAndPlugins } = await import('./utils') const result = await getMarketplaceCollectionsAndPlugins() @@ -1588,24 +1630,16 @@ describe('Async Utils', () => { }) it('should append condition and type to URL when provided', async () => { - globalThis.fetch = vi.fn().mockResolvedValue( - new Response(JSON.stringify({ data: { collections: [] } }), { - status: 200, - headers: { 'Content-Type': 'application/json' }, - }), - ) - + // Assert that the client was called with query containing condition/type const { getMarketplaceCollectionsAndPlugins } = await import('./utils') await getMarketplaceCollectionsAndPlugins({ condition: 'category=tool', type: 'bundle', }) - // oRPC uses Request objects, so check that fetch was called with a Request containing the right URL - expect(globalThis.fetch).toHaveBeenCalled() - const call = vi.mocked(globalThis.fetch).mock.calls[0] - const request = call[0] as Request - expect(request.url).toContain('condition=category%3Dtool') + expect(marketplaceClientMock.collections).toHaveBeenCalled() + const call = marketplaceClientMock.collections.mock.calls[0] + expect(call[0]).toMatchObject({ query: expect.objectContaining({ condition: 'category=tool', type: 'bundle' }) }) }) }) }) diff --git a/web/app/components/tools/mcp/create-card.spec.tsx b/web/app/components/tools/mcp/create-card.spec.tsx new file mode 100644 index 0000000000..9ddee00460 --- /dev/null +++ b/web/app/components/tools/mcp/create-card.spec.tsx @@ -0,0 +1,221 @@ +import type { ReactNode } from 'react' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import * as React from 'react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import NewMCPCard from './create-card' + +// Track the mock functions +const mockCreateMCP = vi.fn().mockResolvedValue({ id: 'new-mcp-id', name: 'New MCP' }) + +// Mock the service +vi.mock('@/service/use-tools', () => ({ + useCreateMCP: () => ({ + mutateAsync: mockCreateMCP, + }), +})) + +// Mock the MCP Modal +type MockMCPModalProps = { + show: boolean + onConfirm: (info: { name: string, server_url: string }) => void + onHide: () => void +} + +vi.mock('./modal', () => ({ + default: ({ show, onConfirm, onHide }: MockMCPModalProps) => { + if (!show) + return null + return ( +
+ tools.mcp.modal.title + + +
+ ) + }, +})) + +// Mutable workspace manager state +let mockIsCurrentWorkspaceManager = true + +// Mock the app context +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceManager: mockIsCurrentWorkspaceManager, + isCurrentWorkspaceEditor: true, + }), +})) + +// Mock the plugins service +vi.mock('@/service/use-plugins', () => ({ + useInstalledPluginList: () => ({ + data: { pages: [] }, + hasNextPage: false, + isFetchingNextPage: false, + fetchNextPage: vi.fn(), + isLoading: false, + isSuccess: true, + }), +})) + +// Mock common service +vi.mock('@/service/common', () => ({ + uploadRemoteFileInfo: vi.fn().mockResolvedValue({ url: 'https://example.com/icon.png' }), +})) + +describe('NewMCPCard', () => { + const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + return ({ children }: { children: ReactNode }) => + React.createElement(QueryClientProvider, { client: queryClient }, children) + } + + const defaultProps = { + handleCreate: vi.fn(), + } + + beforeEach(() => { + mockCreateMCP.mockClear() + mockIsCurrentWorkspaceManager = true + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.create.cardTitle')).toBeInTheDocument() + }) + + it('should render card title', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.create.cardTitle')).toBeInTheDocument() + }) + + it('should render documentation link', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.create.cardLink')).toBeInTheDocument() + }) + + it('should render add icon', () => { + render(, { wrapper: createWrapper() }) + const svgElements = document.querySelectorAll('svg') + expect(svgElements.length).toBeGreaterThan(0) + }) + }) + + describe('User Interactions', () => { + it('should open modal when card is clicked', async () => { + render(, { wrapper: createWrapper() }) + + const cardTitle = screen.getByText('tools.mcp.create.cardTitle') + const clickableArea = cardTitle.closest('.group') + + if (clickableArea) { + fireEvent.click(clickableArea) + + await waitFor(() => { + expect(screen.getByText('tools.mcp.modal.title')).toBeInTheDocument() + }) + } + }) + + it('should have documentation link with correct target', () => { + render(, { wrapper: createWrapper() }) + + const docLink = screen.getByText('tools.mcp.create.cardLink').closest('a') + expect(docLink).toHaveAttribute('target', '_blank') + expect(docLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + }) + + describe('Non-Manager User', () => { + it('should not render card when user is not workspace manager', () => { + mockIsCurrentWorkspaceManager = false + + render(, { wrapper: createWrapper() }) + + expect(screen.queryByText('tools.mcp.create.cardTitle')).not.toBeInTheDocument() + }) + }) + + describe('Styling', () => { + it('should have correct card structure', () => { + render(, { wrapper: createWrapper() }) + + const card = document.querySelector('.rounded-xl') + expect(card).toBeInTheDocument() + }) + + it('should have clickable cursor style', () => { + render(, { wrapper: createWrapper() }) + + const card = document.querySelector('.cursor-pointer') + expect(card).toBeInTheDocument() + }) + }) + + describe('Modal Interactions', () => { + it('should call create function when modal confirms', async () => { + const handleCreate = vi.fn() + render(, { wrapper: createWrapper() }) + + // Open the modal + const cardTitle = screen.getByText('tools.mcp.create.cardTitle') + const clickableArea = cardTitle.closest('.group') + + if (clickableArea) { + fireEvent.click(clickableArea) + + await waitFor(() => { + expect(screen.getByTestId('mcp-modal')).toBeInTheDocument() + }) + + // Click confirm + const confirmBtn = screen.getByTestId('confirm-btn') + fireEvent.click(confirmBtn) + + await waitFor(() => { + expect(mockCreateMCP).toHaveBeenCalledWith({ + name: 'Test MCP', + server_url: 'https://test.com', + }) + expect(handleCreate).toHaveBeenCalled() + }) + } + }) + + it('should close modal when close button is clicked', async () => { + render(, { wrapper: createWrapper() }) + + // Open the modal + const cardTitle = screen.getByText('tools.mcp.create.cardTitle') + const clickableArea = cardTitle.closest('.group') + + if (clickableArea) { + fireEvent.click(clickableArea) + + await waitFor(() => { + expect(screen.getByTestId('mcp-modal')).toBeInTheDocument() + }) + + // Click close + const closeBtn = screen.getByTestId('close-btn') + fireEvent.click(closeBtn) + + await waitFor(() => { + expect(screen.queryByTestId('mcp-modal')).not.toBeInTheDocument() + }) + } + }) + }) +}) diff --git a/web/app/components/tools/mcp/detail/content.spec.tsx b/web/app/components/tools/mcp/detail/content.spec.tsx new file mode 100644 index 0000000000..fe3fbd2bc3 --- /dev/null +++ b/web/app/components/tools/mcp/detail/content.spec.tsx @@ -0,0 +1,855 @@ +import type { ReactNode } from 'react' +import type { ToolWithProvider } from '@/app/components/workflow/types' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import * as React from 'react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import MCPDetailContent from './content' + +// Mutable mock functions +const mockUpdateTools = vi.fn().mockResolvedValue({}) +const mockAuthorizeMcp = vi.fn().mockResolvedValue({ result: 'success' }) +const mockUpdateMCP = vi.fn().mockResolvedValue({ result: 'success' }) +const mockDeleteMCP = vi.fn().mockResolvedValue({ result: 'success' }) +const mockInvalidateMCPTools = vi.fn() +const mockOpenOAuthPopup = vi.fn() + +// Mutable mock state +type MockTool = { + id: string + name: string + description?: string +} + +let mockToolsData: { tools: MockTool[] } = { tools: [] } +let mockIsFetching = false +let mockIsUpdating = false +let mockIsAuthorizing = false + +// Mock the services +vi.mock('@/service/use-tools', () => ({ + useMCPTools: () => ({ + data: mockToolsData, + isFetching: mockIsFetching, + }), + useInvalidateMCPTools: () => mockInvalidateMCPTools, + useUpdateMCPTools: () => ({ + mutateAsync: mockUpdateTools, + isPending: mockIsUpdating, + }), + useAuthorizeMCP: () => ({ + mutateAsync: mockAuthorizeMcp, + isPending: mockIsAuthorizing, + }), + useUpdateMCP: () => ({ + mutateAsync: mockUpdateMCP, + }), + useDeleteMCP: () => ({ + mutateAsync: mockDeleteMCP, + }), +})) + +// Mock OAuth hook +type OAuthArgs = readonly unknown[] +vi.mock('@/hooks/use-oauth', () => ({ + openOAuthPopup: (...args: OAuthArgs) => mockOpenOAuthPopup(...args), +})) + +// Mock MCPModal +type MCPModalData = { + name: string + server_url: string +} + +type MCPModalProps = { + show: boolean + onConfirm: (data: MCPModalData) => void + onHide: () => void +} + +vi.mock('../modal', () => ({ + default: ({ show, onConfirm, onHide }: MCPModalProps) => { + if (!show) + return null + return ( +
+ + +
+ ) + }, +})) + +// Mock Confirm dialog +vi.mock('@/app/components/base/confirm', () => ({ + default: ({ isShow, onConfirm, onCancel, title }: { isShow: boolean, onConfirm: () => void, onCancel: () => void, title: string }) => { + if (!isShow) + return null + return ( +
+ + +
+ ) + }, +})) + +// Mock OperationDropdown +vi.mock('./operation-dropdown', () => ({ + default: ({ onEdit, onRemove }: { onEdit: () => void, onRemove: () => void }) => ( +
+ + +
+ ), +})) + +// Mock ToolItem +type ToolItemData = { + name: string +} + +vi.mock('./tool-item', () => ({ + default: ({ tool }: { tool: ToolItemData }) => ( +
{tool.name}
+ ), +})) + +// Mutable workspace manager state +let mockIsCurrentWorkspaceManager = true + +// Mock the app context +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceManager: mockIsCurrentWorkspaceManager, + isCurrentWorkspaceEditor: true, + }), +})) + +// Mock the plugins service +vi.mock('@/service/use-plugins', () => ({ + useInstalledPluginList: () => ({ + data: { pages: [] }, + hasNextPage: false, + isFetchingNextPage: false, + fetchNextPage: vi.fn(), + isLoading: false, + isSuccess: true, + }), +})) + +// Mock common service +vi.mock('@/service/common', () => ({ + uploadRemoteFileInfo: vi.fn().mockResolvedValue({ url: 'https://example.com/icon.png' }), +})) + +// Mock copy-to-clipboard +vi.mock('copy-to-clipboard', () => ({ + default: vi.fn(), +})) + +describe('MCPDetailContent', () => { + const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + return ({ children }: { children: ReactNode }) => + React.createElement(QueryClientProvider, { client: queryClient }, children) + } + + const createMockDetail = (overrides = {}): ToolWithProvider => ({ + id: 'mcp-1', + name: 'Test MCP Server', + server_identifier: 'test-mcp', + server_url: 'https://example.com/mcp', + icon: { content: '🔧', background: '#FF0000' }, + tools: [], + is_team_authorization: false, + ...overrides, + } as unknown as ToolWithProvider) + + const defaultProps = { + detail: createMockDetail(), + onUpdate: vi.fn(), + onHide: vi.fn(), + isTriggerAuthorize: false, + onFirstCreate: vi.fn(), + } + + beforeEach(() => { + // Reset mocks + mockUpdateTools.mockClear() + mockAuthorizeMcp.mockClear() + mockUpdateMCP.mockClear() + mockDeleteMCP.mockClear() + mockInvalidateMCPTools.mockClear() + mockOpenOAuthPopup.mockClear() + + // Reset mock return values + mockUpdateTools.mockResolvedValue({}) + mockAuthorizeMcp.mockResolvedValue({ result: 'success' }) + mockUpdateMCP.mockResolvedValue({ result: 'success' }) + mockDeleteMCP.mockResolvedValue({ result: 'success' }) + + // Reset state + mockToolsData = { tools: [] } + mockIsFetching = false + mockIsUpdating = false + mockIsAuthorizing = false + mockIsCurrentWorkspaceManager = true + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('Test MCP Server')).toBeInTheDocument() + }) + + it('should display MCP name', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('Test MCP Server')).toBeInTheDocument() + }) + + it('should display server identifier', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('test-mcp')).toBeInTheDocument() + }) + + it('should display server URL', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('https://example.com/mcp')).toBeInTheDocument() + }) + + it('should render close button', () => { + render(, { wrapper: createWrapper() }) + // Close button should be present + const closeButtons = document.querySelectorAll('button') + expect(closeButtons.length).toBeGreaterThan(0) + }) + + it('should render operation dropdown', () => { + render(, { wrapper: createWrapper() }) + // Operation dropdown trigger should be present + expect(document.querySelector('button')).toBeInTheDocument() + }) + }) + + describe('Authorization State', () => { + it('should show authorize button when not authorized', () => { + const detail = createMockDetail({ is_team_authorization: false }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.authorize')).toBeInTheDocument() + }) + + it('should show authorized button when authorized', () => { + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.auth.authorized')).toBeInTheDocument() + }) + + it('should show authorization required message when not authorized', () => { + const detail = createMockDetail({ is_team_authorization: false }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.authorizingRequired')).toBeInTheDocument() + }) + + it('should show authorization tip', () => { + const detail = createMockDetail({ is_team_authorization: false }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.authorizeTip')).toBeInTheDocument() + }) + }) + + describe('Empty Tools State', () => { + it('should show empty message when authorized but no tools', () => { + const detail = createMockDetail({ is_team_authorization: true, tools: [] }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.toolsEmpty')).toBeInTheDocument() + }) + + it('should show get tools button when empty', () => { + const detail = createMockDetail({ is_team_authorization: true, tools: [] }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.getTools')).toBeInTheDocument() + }) + }) + + describe('Icon Display', () => { + it('should render MCP icon', () => { + render(, { wrapper: createWrapper() }) + // Icon container should be present + const iconContainer = document.querySelector('[class*="rounded-xl"][class*="border"]') + expect(iconContainer).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should handle empty server URL', () => { + const detail = createMockDetail({ server_url: '' }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('Test MCP Server')).toBeInTheDocument() + }) + + it('should handle long MCP name', () => { + const longName = 'A'.repeat(100) + const detail = createMockDetail({ name: longName }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText(longName)).toBeInTheDocument() + }) + }) + + describe('Tools List', () => { + it('should show tools list when authorized and has tools', () => { + mockToolsData = { + tools: [ + { id: 'tool1', name: 'tool1', description: 'Tool 1' }, + { id: 'tool2', name: 'tool2', description: 'Tool 2' }, + ], + } + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tool1')).toBeInTheDocument() + expect(screen.getByText('tool2')).toBeInTheDocument() + }) + + it('should show single tool label when only one tool', () => { + mockToolsData = { + tools: [{ id: 'tool1', name: 'tool1', description: 'Tool 1' }], + } + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.onlyTool')).toBeInTheDocument() + }) + + it('should show tools count when multiple tools', () => { + mockToolsData = { + tools: [ + { id: 'tool1', name: 'tool1', description: 'Tool 1' }, + { id: 'tool2', name: 'tool2', description: 'Tool 2' }, + ], + } + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText(/tools.mcp.toolsNum/)).toBeInTheDocument() + }) + }) + + describe('Loading States', () => { + it('should show loading state when fetching tools', () => { + mockIsFetching = true + mockToolsData = { + tools: [{ id: 'tool1', name: 'tool1', description: 'Tool 1' }], + } + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.gettingTools')).toBeInTheDocument() + }) + + it('should show updating state when updating tools', () => { + mockIsUpdating = true + mockToolsData = { + tools: [{ id: 'tool1', name: 'tool1', description: 'Tool 1' }], + } + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.mcp.updateTools')).toBeInTheDocument() + }) + + it('should show authorizing button when authorizing', () => { + mockIsAuthorizing = true + const detail = createMockDetail({ is_team_authorization: false }) + render( + , + { wrapper: createWrapper() }, + ) + // Multiple elements show authorizing text - use getAllByText + const authorizingElements = screen.getAllByText('tools.mcp.authorizing') + expect(authorizingElements.length).toBeGreaterThan(0) + }) + }) + + describe('Authorize Flow', () => { + it('should call authorizeMcp when authorize button is clicked', async () => { + const onFirstCreate = vi.fn() + const detail = createMockDetail({ is_team_authorization: false }) + render( + , + { wrapper: createWrapper() }, + ) + + const authorizeBtn = screen.getByText('tools.mcp.authorize') + fireEvent.click(authorizeBtn) + + await waitFor(() => { + expect(onFirstCreate).toHaveBeenCalled() + expect(mockAuthorizeMcp).toHaveBeenCalledWith({ provider_id: 'mcp-1' }) + }) + }) + + it('should open OAuth popup when authorization_url is returned', async () => { + mockAuthorizeMcp.mockResolvedValue({ authorization_url: 'https://oauth.example.com' }) + const detail = createMockDetail({ is_team_authorization: false }) + render( + , + { wrapper: createWrapper() }, + ) + + const authorizeBtn = screen.getByText('tools.mcp.authorize') + fireEvent.click(authorizeBtn) + + await waitFor(() => { + expect(mockOpenOAuthPopup).toHaveBeenCalledWith( + 'https://oauth.example.com', + expect.any(Function), + ) + }) + }) + + it('should trigger authorize on mount when isTriggerAuthorize is true', async () => { + const onFirstCreate = vi.fn() + const detail = createMockDetail({ is_team_authorization: false }) + render( + , + { wrapper: createWrapper() }, + ) + + await waitFor(() => { + expect(onFirstCreate).toHaveBeenCalled() + expect(mockAuthorizeMcp).toHaveBeenCalled() + }) + }) + + it('should disable authorize button when not workspace manager', () => { + mockIsCurrentWorkspaceManager = false + const detail = createMockDetail({ is_team_authorization: false }) + render( + , + { wrapper: createWrapper() }, + ) + + const authorizeBtn = screen.getByText('tools.mcp.authorize') + expect(authorizeBtn.closest('button')).toBeDisabled() + }) + }) + + describe('Update Tools Flow', () => { + it('should show update confirm dialog when update button is clicked', async () => { + mockToolsData = { + tools: [{ id: 'tool1', name: 'tool1', description: 'Tool 1' }], + } + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + + const updateBtn = screen.getByText('tools.mcp.update') + fireEvent.click(updateBtn) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + }) + + it('should call updateTools when update is confirmed', async () => { + mockToolsData = { + tools: [{ id: 'tool1', name: 'tool1', description: 'Tool 1' }], + } + const onUpdate = vi.fn() + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + + // Open confirm dialog + const updateBtn = screen.getByText('tools.mcp.update') + fireEvent.click(updateBtn) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + // Confirm the update + const confirmBtn = screen.getByTestId('confirm-btn') + fireEvent.click(confirmBtn) + + await waitFor(() => { + expect(mockUpdateTools).toHaveBeenCalledWith('mcp-1') + expect(mockInvalidateMCPTools).toHaveBeenCalledWith('mcp-1') + expect(onUpdate).toHaveBeenCalled() + }) + }) + + it('should call handleUpdateTools when get tools button is clicked', async () => { + const onUpdate = vi.fn() + const detail = createMockDetail({ is_team_authorization: true, tools: [] }) + render( + , + { wrapper: createWrapper() }, + ) + + const getToolsBtn = screen.getByText('tools.mcp.getTools') + fireEvent.click(getToolsBtn) + + await waitFor(() => { + expect(mockUpdateTools).toHaveBeenCalledWith('mcp-1') + }) + }) + }) + + describe('Update MCP Modal', () => { + it('should open update modal when edit button is clicked', async () => { + render(, { wrapper: createWrapper() }) + + const editBtn = screen.getByTestId('edit-btn') + fireEvent.click(editBtn) + + await waitFor(() => { + expect(screen.getByTestId('mcp-update-modal')).toBeInTheDocument() + }) + }) + + it('should close update modal when close button is clicked', async () => { + render(, { wrapper: createWrapper() }) + + // Open modal + const editBtn = screen.getByTestId('edit-btn') + fireEvent.click(editBtn) + + await waitFor(() => { + expect(screen.getByTestId('mcp-update-modal')).toBeInTheDocument() + }) + + // Close modal + const closeBtn = screen.getByTestId('modal-close-btn') + fireEvent.click(closeBtn) + + await waitFor(() => { + expect(screen.queryByTestId('mcp-update-modal')).not.toBeInTheDocument() + }) + }) + + it('should call updateMCP when form is confirmed', async () => { + const onUpdate = vi.fn() + render(, { wrapper: createWrapper() }) + + // Open modal + const editBtn = screen.getByTestId('edit-btn') + fireEvent.click(editBtn) + + await waitFor(() => { + expect(screen.getByTestId('mcp-update-modal')).toBeInTheDocument() + }) + + // Confirm form + const confirmBtn = screen.getByTestId('modal-confirm-btn') + fireEvent.click(confirmBtn) + + await waitFor(() => { + expect(mockUpdateMCP).toHaveBeenCalledWith({ + name: 'Updated MCP', + server_url: 'https://updated.com', + provider_id: 'mcp-1', + }) + expect(onUpdate).toHaveBeenCalled() + }) + }) + + it('should not call onUpdate when updateMCP fails', async () => { + mockUpdateMCP.mockResolvedValue({ result: 'error' }) + const onUpdate = vi.fn() + render(, { wrapper: createWrapper() }) + + // Open modal + const editBtn = screen.getByTestId('edit-btn') + fireEvent.click(editBtn) + + await waitFor(() => { + expect(screen.getByTestId('mcp-update-modal')).toBeInTheDocument() + }) + + // Confirm form + const confirmBtn = screen.getByTestId('modal-confirm-btn') + fireEvent.click(confirmBtn) + + await waitFor(() => { + expect(mockUpdateMCP).toHaveBeenCalled() + }) + + expect(onUpdate).not.toHaveBeenCalled() + }) + }) + + describe('Delete MCP Flow', () => { + it('should open delete confirm when remove button is clicked', async () => { + render(, { wrapper: createWrapper() }) + + const removeBtn = screen.getByTestId('remove-btn') + fireEvent.click(removeBtn) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + }) + + it('should close delete confirm when cancel is clicked', async () => { + render(, { wrapper: createWrapper() }) + + // Open confirm + const removeBtn = screen.getByTestId('remove-btn') + fireEvent.click(removeBtn) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + // Cancel + const cancelBtn = screen.getByTestId('cancel-btn') + fireEvent.click(cancelBtn) + + await waitFor(() => { + expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument() + }) + }) + + it('should call deleteMCP when delete is confirmed', async () => { + const onUpdate = vi.fn() + render(, { wrapper: createWrapper() }) + + // Open confirm + const removeBtn = screen.getByTestId('remove-btn') + fireEvent.click(removeBtn) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + // Confirm delete + const confirmBtn = screen.getByTestId('confirm-btn') + fireEvent.click(confirmBtn) + + await waitFor(() => { + expect(mockDeleteMCP).toHaveBeenCalledWith('mcp-1') + expect(onUpdate).toHaveBeenCalledWith(true) + }) + }) + + it('should not call onUpdate when deleteMCP fails', async () => { + mockDeleteMCP.mockResolvedValue({ result: 'error' }) + const onUpdate = vi.fn() + render(, { wrapper: createWrapper() }) + + // Open confirm + const removeBtn = screen.getByTestId('remove-btn') + fireEvent.click(removeBtn) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + // Confirm delete + const confirmBtn = screen.getByTestId('confirm-btn') + fireEvent.click(confirmBtn) + + await waitFor(() => { + expect(mockDeleteMCP).toHaveBeenCalled() + }) + + expect(onUpdate).not.toHaveBeenCalled() + }) + }) + + describe('Close Button', () => { + it('should call onHide when close button is clicked', () => { + const onHide = vi.fn() + render(, { wrapper: createWrapper() }) + + // Find the close button (ActionButton with RiCloseLine) + const buttons = screen.getAllByRole('button') + const closeButton = buttons.find(btn => + btn.querySelector('svg.h-4.w-4'), + ) + + if (closeButton) { + fireEvent.click(closeButton) + expect(onHide).toHaveBeenCalled() + } + }) + }) + + describe('Copy Server Identifier', () => { + it('should copy server identifier when clicked', async () => { + const { default: copy } = await import('copy-to-clipboard') + render(, { wrapper: createWrapper() }) + + // Find the server identifier element + const serverIdentifier = screen.getByText('test-mcp') + fireEvent.click(serverIdentifier) + + expect(copy).toHaveBeenCalledWith('test-mcp') + }) + }) + + describe('OAuth Callback', () => { + it('should call handleUpdateTools on OAuth callback when authorized', async () => { + // Simulate OAuth flow with authorization_url + mockAuthorizeMcp.mockResolvedValue({ authorization_url: 'https://oauth.example.com' }) + const onUpdate = vi.fn() + const detail = createMockDetail({ is_team_authorization: false }) + render( + , + { wrapper: createWrapper() }, + ) + + // Click authorize to trigger OAuth popup + const authorizeBtn = screen.getByText('tools.mcp.authorize') + fireEvent.click(authorizeBtn) + + await waitFor(() => { + expect(mockOpenOAuthPopup).toHaveBeenCalled() + }) + + // Get the callback function and call it + const oauthCallback = mockOpenOAuthPopup.mock.calls[0][1] + oauthCallback() + + await waitFor(() => { + expect(mockUpdateTools).toHaveBeenCalledWith('mcp-1') + }) + }) + + it('should not call handleUpdateTools if not workspace manager', async () => { + mockIsCurrentWorkspaceManager = false + mockAuthorizeMcp.mockResolvedValue({ authorization_url: 'https://oauth.example.com' }) + const detail = createMockDetail({ is_team_authorization: false }) + + // OAuth callback should not trigger update for non-manager + // The button is disabled, so we simulate a scenario where OAuth was already started + render( + , + { wrapper: createWrapper() }, + ) + + // Button should be disabled + const authorizeBtn = screen.getByText('tools.mcp.authorize') + expect(authorizeBtn.closest('button')).toBeDisabled() + }) + }) + + describe('Authorized Button', () => { + it('should show authorized button when team is authorized', () => { + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('tools.auth.authorized')).toBeInTheDocument() + }) + + it('should call handleAuthorize when authorized button is clicked', async () => { + const onFirstCreate = vi.fn() + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + + const authorizedBtn = screen.getByText('tools.auth.authorized') + fireEvent.click(authorizedBtn) + + await waitFor(() => { + expect(onFirstCreate).toHaveBeenCalled() + expect(mockAuthorizeMcp).toHaveBeenCalled() + }) + }) + + it('should disable authorized button when not workspace manager', () => { + mockIsCurrentWorkspaceManager = false + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + + const authorizedBtn = screen.getByText('tools.auth.authorized') + expect(authorizedBtn.closest('button')).toBeDisabled() + }) + }) + + describe('Cancel Update Confirm', () => { + it('should close update confirm when cancel is clicked', async () => { + mockToolsData = { + tools: [{ id: 'tool1', name: 'tool1', description: 'Tool 1' }], + } + const detail = createMockDetail({ is_team_authorization: true }) + render( + , + { wrapper: createWrapper() }, + ) + + // Open confirm dialog + const updateBtn = screen.getByText('tools.mcp.update') + fireEvent.click(updateBtn) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + // Cancel the update + const cancelBtn = screen.getByTestId('cancel-btn') + fireEvent.click(cancelBtn) + + await waitFor(() => { + expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument() + }) + }) + }) +}) diff --git a/web/app/components/tools/mcp/detail/list-loading.spec.tsx b/web/app/components/tools/mcp/detail/list-loading.spec.tsx new file mode 100644 index 0000000000..679d4322d9 --- /dev/null +++ b/web/app/components/tools/mcp/detail/list-loading.spec.tsx @@ -0,0 +1,71 @@ +import { render } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import ListLoading from './list-loading' + +describe('ListLoading', () => { + describe('Rendering', () => { + it('should render without crashing', () => { + const { container } = render() + expect(container).toBeInTheDocument() + }) + + it('should render 5 skeleton items', () => { + render() + const skeletonItems = document.querySelectorAll('[class*="bg-components-panel-on-panel-item-bg-hover"]') + expect(skeletonItems.length).toBe(5) + }) + + it('should have rounded-xl class on skeleton items', () => { + render() + const skeletonItems = document.querySelectorAll('.rounded-xl') + expect(skeletonItems.length).toBeGreaterThanOrEqual(5) + }) + + it('should have proper spacing', () => { + render() + const container = document.querySelector('.space-y-2') + expect(container).toBeInTheDocument() + }) + + it('should render placeholder bars with different widths', () => { + render() + const bar180 = document.querySelector('.w-\\[180px\\]') + const bar148 = document.querySelector('.w-\\[148px\\]') + const bar196 = document.querySelector('.w-\\[196px\\]') + + expect(bar180).toBeInTheDocument() + expect(bar148).toBeInTheDocument() + expect(bar196).toBeInTheDocument() + }) + + it('should have opacity styling on skeleton bars', () => { + render() + const opacity20Bars = document.querySelectorAll('.opacity-20') + const opacity10Bars = document.querySelectorAll('.opacity-10') + + expect(opacity20Bars.length).toBeGreaterThan(0) + expect(opacity10Bars.length).toBeGreaterThan(0) + }) + }) + + describe('Structure', () => { + it('should have correct nested structure', () => { + render() + const items = document.querySelectorAll('.space-y-3') + expect(items.length).toBe(5) + }) + + it('should render padding on skeleton items', () => { + render() + const paddedItems = document.querySelectorAll('.p-4') + expect(paddedItems.length).toBe(5) + }) + + it('should render height-2 skeleton bars', () => { + render() + const h2Bars = document.querySelectorAll('.h-2') + // 3 bars per skeleton item * 5 items = 15 + expect(h2Bars.length).toBe(15) + }) + }) +}) diff --git a/web/app/components/tools/mcp/detail/operation-dropdown.spec.tsx b/web/app/components/tools/mcp/detail/operation-dropdown.spec.tsx new file mode 100644 index 0000000000..077bdc3efe --- /dev/null +++ b/web/app/components/tools/mcp/detail/operation-dropdown.spec.tsx @@ -0,0 +1,193 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import OperationDropdown from './operation-dropdown' + +describe('OperationDropdown', () => { + const defaultProps = { + onEdit: vi.fn(), + onRemove: vi.fn(), + } + + describe('Rendering', () => { + it('should render without crashing', () => { + render() + expect(document.querySelector('button')).toBeInTheDocument() + }) + + it('should render trigger button with more icon', () => { + render() + const button = document.querySelector('button') + expect(button).toBeInTheDocument() + const svg = button?.querySelector('svg') + expect(svg).toBeInTheDocument() + }) + + it('should render medium size by default', () => { + render() + const icon = document.querySelector('.h-4.w-4') + expect(icon).toBeInTheDocument() + }) + + it('should render large size when inCard is true', () => { + render() + const icon = document.querySelector('.h-5.w-5') + expect(icon).toBeInTheDocument() + }) + }) + + describe('Dropdown Behavior', () => { + it('should open dropdown when trigger is clicked', async () => { + render() + + const trigger = document.querySelector('button') + if (trigger) { + fireEvent.click(trigger) + + // Dropdown content should be rendered + expect(screen.getByText('tools.mcp.operation.edit')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.operation.remove')).toBeInTheDocument() + } + }) + + it('should call onOpenChange when opened', () => { + const onOpenChange = vi.fn() + render() + + const trigger = document.querySelector('button') + if (trigger) { + fireEvent.click(trigger) + expect(onOpenChange).toHaveBeenCalledWith(true) + } + }) + + it('should close dropdown when trigger is clicked again', async () => { + const onOpenChange = vi.fn() + render() + + const trigger = document.querySelector('button') + if (trigger) { + fireEvent.click(trigger) + fireEvent.click(trigger) + expect(onOpenChange).toHaveBeenLastCalledWith(false) + } + }) + }) + + describe('Menu Actions', () => { + it('should call onEdit when edit option is clicked', () => { + const onEdit = vi.fn() + render() + + const trigger = document.querySelector('button') + if (trigger) { + fireEvent.click(trigger) + + const editOption = screen.getByText('tools.mcp.operation.edit') + fireEvent.click(editOption) + + expect(onEdit).toHaveBeenCalledTimes(1) + } + }) + + it('should call onRemove when remove option is clicked', () => { + const onRemove = vi.fn() + render() + + const trigger = document.querySelector('button') + if (trigger) { + fireEvent.click(trigger) + + const removeOption = screen.getByText('tools.mcp.operation.remove') + fireEvent.click(removeOption) + + expect(onRemove).toHaveBeenCalledTimes(1) + } + }) + + it('should close dropdown after edit is clicked', () => { + const onOpenChange = vi.fn() + render() + + const trigger = document.querySelector('button') + if (trigger) { + fireEvent.click(trigger) + onOpenChange.mockClear() + + const editOption = screen.getByText('tools.mcp.operation.edit') + fireEvent.click(editOption) + + expect(onOpenChange).toHaveBeenCalledWith(false) + } + }) + + it('should close dropdown after remove is clicked', () => { + const onOpenChange = vi.fn() + render() + + const trigger = document.querySelector('button') + if (trigger) { + fireEvent.click(trigger) + onOpenChange.mockClear() + + const removeOption = screen.getByText('tools.mcp.operation.remove') + fireEvent.click(removeOption) + + expect(onOpenChange).toHaveBeenCalledWith(false) + } + }) + }) + + describe('Styling', () => { + it('should have correct dropdown width', () => { + render() + + const trigger = document.querySelector('button') + if (trigger) { + fireEvent.click(trigger) + + const dropdown = document.querySelector('.w-\\[160px\\]') + expect(dropdown).toBeInTheDocument() + } + }) + + it('should have rounded-xl on dropdown', () => { + render() + + const trigger = document.querySelector('button') + if (trigger) { + fireEvent.click(trigger) + + const dropdown = document.querySelector('[class*="rounded-xl"][class*="border"]') + expect(dropdown).toBeInTheDocument() + } + }) + + it('should show destructive hover style on remove option', () => { + render() + + const trigger = document.querySelector('button') + if (trigger) { + fireEvent.click(trigger) + + // The text is in a div, and the hover style is on the parent div with group class + const removeOptionText = screen.getByText('tools.mcp.operation.remove') + const removeOptionContainer = removeOptionText.closest('.group') + expect(removeOptionContainer).toHaveClass('hover:bg-state-destructive-hover') + } + }) + }) + + describe('inCard prop', () => { + it('should adjust offset when inCard is false', () => { + render() + // Component renders with different offset values + expect(document.querySelector('button')).toBeInTheDocument() + }) + + it('should adjust offset when inCard is true', () => { + render() + // Component renders with different offset values + expect(document.querySelector('button')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/tools/mcp/detail/provider-detail.spec.tsx b/web/app/components/tools/mcp/detail/provider-detail.spec.tsx new file mode 100644 index 0000000000..dc8a427498 --- /dev/null +++ b/web/app/components/tools/mcp/detail/provider-detail.spec.tsx @@ -0,0 +1,153 @@ +import type { ReactNode } from 'react' +import type { ToolWithProvider } from '@/app/components/workflow/types' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { fireEvent, render, screen } from '@testing-library/react' +import * as React from 'react' +import { describe, expect, it, vi } from 'vitest' +import MCPDetailPanel from './provider-detail' + +// Mock the drawer component +vi.mock('@/app/components/base/drawer', () => ({ + default: ({ children, isOpen }: { children: ReactNode, isOpen: boolean }) => { + if (!isOpen) + return null + return
{children}
+ }, +})) + +// Mock the content component to expose onUpdate callback +vi.mock('./content', () => ({ + default: ({ detail, onUpdate }: { detail: ToolWithProvider, onUpdate: (isDelete?: boolean) => void }) => ( +
+ {detail.name} + + +
+ ), +})) + +describe('MCPDetailPanel', () => { + const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + return ({ children }: { children: ReactNode }) => + React.createElement(QueryClientProvider, { client: queryClient }, children) + } + + const createMockDetail = (): ToolWithProvider => ({ + id: 'mcp-1', + name: 'Test MCP', + server_identifier: 'test-mcp', + server_url: 'https://example.com/mcp', + icon: { content: '🔧', background: '#FF0000' }, + tools: [], + is_team_authorization: true, + } as unknown as ToolWithProvider) + + const defaultProps = { + onUpdate: vi.fn(), + onHide: vi.fn(), + isTriggerAuthorize: false, + onFirstCreate: vi.fn(), + } + + describe('Rendering', () => { + it('should render nothing when detail is undefined', () => { + const { container } = render( + , + { wrapper: createWrapper() }, + ) + expect(container.innerHTML).toBe('') + }) + + it('should render drawer when detail is provided', () => { + const detail = createMockDetail() + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByTestId('drawer')).toBeInTheDocument() + }) + + it('should render content when detail is provided', () => { + const detail = createMockDetail() + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByTestId('mcp-detail-content')).toBeInTheDocument() + }) + + it('should pass detail to content component', () => { + const detail = createMockDetail() + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByText('Test MCP')).toBeInTheDocument() + }) + }) + + describe('Callbacks', () => { + it('should call onUpdate when update is triggered', () => { + const onUpdate = vi.fn() + const detail = createMockDetail() + render( + , + { wrapper: createWrapper() }, + ) + // The update callback is passed to content component + expect(screen.getByTestId('mcp-detail-content')).toBeInTheDocument() + }) + + it('should accept isTriggerAuthorize prop', () => { + const detail = createMockDetail() + render( + , + { wrapper: createWrapper() }, + ) + expect(screen.getByTestId('mcp-detail-content')).toBeInTheDocument() + }) + }) + + describe('handleUpdate', () => { + it('should call onUpdate but not onHide when isDelete is false (default)', () => { + const onUpdate = vi.fn() + const onHide = vi.fn() + const detail = createMockDetail() + render( + , + { wrapper: createWrapper() }, + ) + + // Click update button which calls onUpdate() without isDelete parameter + const updateBtn = screen.getByTestId('update-btn') + fireEvent.click(updateBtn) + + expect(onUpdate).toHaveBeenCalledTimes(1) + expect(onHide).not.toHaveBeenCalled() + }) + + it('should call both onHide and onUpdate when isDelete is true', () => { + const onUpdate = vi.fn() + const onHide = vi.fn() + const detail = createMockDetail() + render( + , + { wrapper: createWrapper() }, + ) + + // Click delete button which calls onUpdate(true) + const deleteBtn = screen.getByTestId('delete-btn') + fireEvent.click(deleteBtn) + + expect(onHide).toHaveBeenCalledTimes(1) + expect(onUpdate).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/tools/mcp/detail/tool-item.spec.tsx b/web/app/components/tools/mcp/detail/tool-item.spec.tsx new file mode 100644 index 0000000000..aa04422b48 --- /dev/null +++ b/web/app/components/tools/mcp/detail/tool-item.spec.tsx @@ -0,0 +1,126 @@ +import type { Tool } from '@/app/components/tools/types' +import { render, screen } from '@testing-library/react' +import { describe, expect, it } from 'vitest' +import MCPToolItem from './tool-item' + +describe('MCPToolItem', () => { + const createMockTool = (overrides = {}): Tool => ({ + name: 'test-tool', + label: { + en_US: 'Test Tool', + zh_Hans: '测试工具', + }, + description: { + en_US: 'A test tool description', + zh_Hans: '测试工具描述', + }, + parameters: [], + ...overrides, + } as unknown as Tool) + + describe('Rendering', () => { + it('should render without crashing', () => { + const tool = createMockTool() + render() + expect(screen.getByText('Test Tool')).toBeInTheDocument() + }) + + it('should display tool label', () => { + const tool = createMockTool() + render() + expect(screen.getByText('Test Tool')).toBeInTheDocument() + }) + + it('should display tool description', () => { + const tool = createMockTool() + render() + expect(screen.getByText('A test tool description')).toBeInTheDocument() + }) + }) + + describe('With Parameters', () => { + it('should not show parameters section when no parameters', () => { + const tool = createMockTool({ parameters: [] }) + render() + expect(screen.queryByText('tools.mcp.toolItem.parameters')).not.toBeInTheDocument() + }) + + it('should render with parameters', () => { + const tool = createMockTool({ + parameters: [ + { + name: 'param1', + type: 'string', + human_description: { + en_US: 'A parameter description', + }, + }, + ], + }) + render() + // Tooltip content is rendered in portal, may not be visible immediately + expect(screen.getByText('Test Tool')).toBeInTheDocument() + }) + }) + + describe('Styling', () => { + it('should have cursor-pointer class', () => { + const tool = createMockTool() + render() + const toolElement = document.querySelector('.cursor-pointer') + expect(toolElement).toBeInTheDocument() + }) + + it('should have rounded-xl class', () => { + const tool = createMockTool() + render() + const toolElement = document.querySelector('.rounded-xl') + expect(toolElement).toBeInTheDocument() + }) + + it('should have hover styles', () => { + const tool = createMockTool() + render() + const toolElement = document.querySelector('[class*="hover:bg-components-panel-on-panel-item-bg-hover"]') + expect(toolElement).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should handle empty label', () => { + const tool = createMockTool({ + label: { en_US: '', zh_Hans: '' }, + }) + render() + // Should render without crashing + expect(document.querySelector('.cursor-pointer')).toBeInTheDocument() + }) + + it('should handle empty description', () => { + const tool = createMockTool({ + description: { en_US: '', zh_Hans: '' }, + }) + render() + expect(screen.getByText('Test Tool')).toBeInTheDocument() + }) + + it('should handle long description with line clamp', () => { + const longDescription = 'This is a very long description '.repeat(20) + const tool = createMockTool({ + description: { en_US: longDescription, zh_Hans: longDescription }, + }) + render() + const descElement = document.querySelector('.line-clamp-2') + expect(descElement).toBeInTheDocument() + }) + + it('should handle special characters in tool name', () => { + const tool = createMockTool({ + name: 'special-tool_v2.0', + label: { en_US: 'Special Tool ', zh_Hans: '特殊工具' }, + }) + render() + expect(screen.getByText('Special Tool ')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/tools/mcp/headers-input.spec.tsx b/web/app/components/tools/mcp/headers-input.spec.tsx new file mode 100644 index 0000000000..c271268f5f --- /dev/null +++ b/web/app/components/tools/mcp/headers-input.spec.tsx @@ -0,0 +1,245 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import HeadersInput from './headers-input' + +describe('HeadersInput', () => { + const defaultProps = { + headersItems: [], + onChange: vi.fn(), + } + + describe('Empty State', () => { + it('should render no headers message when empty', () => { + render() + expect(screen.getByText('tools.mcp.modal.noHeaders')).toBeInTheDocument() + }) + + it('should render add header button when empty and not readonly', () => { + render() + expect(screen.getByText('tools.mcp.modal.addHeader')).toBeInTheDocument() + }) + + it('should not render add header button when empty and readonly', () => { + render() + expect(screen.queryByText('tools.mcp.modal.addHeader')).not.toBeInTheDocument() + }) + + it('should call onChange with new item when add button is clicked', () => { + const onChange = vi.fn() + render() + + const addButton = screen.getByText('tools.mcp.modal.addHeader') + fireEvent.click(addButton) + + expect(onChange).toHaveBeenCalledWith([ + expect.objectContaining({ + key: '', + value: '', + }), + ]) + }) + }) + + describe('With Headers', () => { + const headersItems = [ + { id: '1', key: 'Authorization', value: 'Bearer token123' }, + { id: '2', key: 'Content-Type', value: 'application/json' }, + ] + + it('should render header items', () => { + render() + expect(screen.getByDisplayValue('Authorization')).toBeInTheDocument() + expect(screen.getByDisplayValue('Bearer token123')).toBeInTheDocument() + expect(screen.getByDisplayValue('Content-Type')).toBeInTheDocument() + expect(screen.getByDisplayValue('application/json')).toBeInTheDocument() + }) + + it('should render table headers', () => { + render() + expect(screen.getByText('tools.mcp.modal.headerKey')).toBeInTheDocument() + expect(screen.getByText('tools.mcp.modal.headerValue')).toBeInTheDocument() + }) + + it('should render delete buttons for each item when not readonly', () => { + render() + // Should have delete buttons for each header + const deleteButtons = document.querySelectorAll('[class*="text-text-destructive"]') + expect(deleteButtons.length).toBe(headersItems.length) + }) + + it('should not render delete buttons when readonly', () => { + render() + const deleteButtons = document.querySelectorAll('[class*="text-text-destructive"]') + expect(deleteButtons.length).toBe(0) + }) + + it('should render add button at bottom when not readonly', () => { + render() + expect(screen.getByText('tools.mcp.modal.addHeader')).toBeInTheDocument() + }) + + it('should not render add button when readonly', () => { + render() + expect(screen.queryByText('tools.mcp.modal.addHeader')).not.toBeInTheDocument() + }) + }) + + describe('Masked Headers', () => { + const headersItems = [{ id: '1', key: 'Secret', value: '***' }] + + it('should show masked headers tip when isMasked is true', () => { + render() + expect(screen.getByText('tools.mcp.modal.maskedHeadersTip')).toBeInTheDocument() + }) + + it('should not show masked headers tip when isMasked is false', () => { + render() + expect(screen.queryByText('tools.mcp.modal.maskedHeadersTip')).not.toBeInTheDocument() + }) + }) + + describe('Item Interactions', () => { + const headersItems = [ + { id: '1', key: 'Header1', value: 'Value1' }, + ] + + it('should call onChange when key is changed', () => { + const onChange = vi.fn() + render() + + const keyInput = screen.getByDisplayValue('Header1') + fireEvent.change(keyInput, { target: { value: 'NewHeader' } }) + + expect(onChange).toHaveBeenCalledWith([ + { id: '1', key: 'NewHeader', value: 'Value1' }, + ]) + }) + + it('should call onChange when value is changed', () => { + const onChange = vi.fn() + render() + + const valueInput = screen.getByDisplayValue('Value1') + fireEvent.change(valueInput, { target: { value: 'NewValue' } }) + + expect(onChange).toHaveBeenCalledWith([ + { id: '1', key: 'Header1', value: 'NewValue' }, + ]) + }) + + it('should remove item when delete button is clicked', () => { + const onChange = vi.fn() + render() + + const deleteButton = document.querySelector('[class*="text-text-destructive"]')?.closest('button') + if (deleteButton) { + fireEvent.click(deleteButton) + expect(onChange).toHaveBeenCalledWith([]) + } + }) + + it('should add new item when add button is clicked', () => { + const onChange = vi.fn() + render() + + const addButton = screen.getByText('tools.mcp.modal.addHeader') + fireEvent.click(addButton) + + expect(onChange).toHaveBeenCalledWith([ + { id: '1', key: 'Header1', value: 'Value1' }, + expect.objectContaining({ key: '', value: '' }), + ]) + }) + }) + + describe('Multiple Headers', () => { + const headersItems = [ + { id: '1', key: 'Header1', value: 'Value1' }, + { id: '2', key: 'Header2', value: 'Value2' }, + { id: '3', key: 'Header3', value: 'Value3' }, + ] + + it('should render all headers', () => { + render() + expect(screen.getByDisplayValue('Header1')).toBeInTheDocument() + expect(screen.getByDisplayValue('Header2')).toBeInTheDocument() + expect(screen.getByDisplayValue('Header3')).toBeInTheDocument() + }) + + it('should update correct item when changed', () => { + const onChange = vi.fn() + render() + + const header2Input = screen.getByDisplayValue('Header2') + fireEvent.change(header2Input, { target: { value: 'UpdatedHeader2' } }) + + expect(onChange).toHaveBeenCalledWith([ + { id: '1', key: 'Header1', value: 'Value1' }, + { id: '2', key: 'UpdatedHeader2', value: 'Value2' }, + { id: '3', key: 'Header3', value: 'Value3' }, + ]) + }) + + it('should remove correct item when deleted', () => { + const onChange = vi.fn() + render() + + // Find all delete buttons and click the second one + const deleteButtons = document.querySelectorAll('[class*="text-text-destructive"]') + const secondDeleteButton = deleteButtons[1]?.closest('button') + if (secondDeleteButton) { + fireEvent.click(secondDeleteButton) + expect(onChange).toHaveBeenCalledWith([ + { id: '1', key: 'Header1', value: 'Value1' }, + { id: '3', key: 'Header3', value: 'Value3' }, + ]) + } + }) + }) + + describe('Readonly Mode', () => { + const headersItems = [{ id: '1', key: 'ReadOnly', value: 'Value' }] + + it('should make inputs readonly when readonly is true', () => { + render() + + const keyInput = screen.getByDisplayValue('ReadOnly') + const valueInput = screen.getByDisplayValue('Value') + + expect(keyInput).toHaveAttribute('readonly') + expect(valueInput).toHaveAttribute('readonly') + }) + + it('should not make inputs readonly when readonly is false', () => { + render() + + const keyInput = screen.getByDisplayValue('ReadOnly') + const valueInput = screen.getByDisplayValue('Value') + + expect(keyInput).not.toHaveAttribute('readonly') + expect(valueInput).not.toHaveAttribute('readonly') + }) + }) + + describe('Edge Cases', () => { + it('should handle empty key and value', () => { + const headersItems = [{ id: '1', key: '', value: '' }] + render() + + const inputs = screen.getAllByRole('textbox') + expect(inputs.length).toBe(2) + }) + + it('should handle special characters in header key', () => { + const headersItems = [{ id: '1', key: 'X-Custom-Header', value: 'value' }] + render() + expect(screen.getByDisplayValue('X-Custom-Header')).toBeInTheDocument() + }) + + it('should handle JSON value', () => { + const headersItems = [{ id: '1', key: 'Data', value: '{"key":"value"}' }] + render() + expect(screen.getByDisplayValue('{"key":"value"}')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/tools/mcp/hooks/use-mcp-modal-form.spec.ts b/web/app/components/tools/mcp/hooks/use-mcp-modal-form.spec.ts new file mode 100644 index 0000000000..72520e11d1 --- /dev/null +++ b/web/app/components/tools/mcp/hooks/use-mcp-modal-form.spec.ts @@ -0,0 +1,500 @@ +import type { AppIconEmojiSelection, AppIconImageSelection } from '@/app/components/base/app-icon-picker' +import type { ToolWithProvider } from '@/app/components/workflow/types' +import { act, renderHook } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import { MCPAuthMethod } from '@/app/components/tools/types' +import { isValidServerID, isValidUrl, useMCPModalForm } from './use-mcp-modal-form' + +// Mock the API service +vi.mock('@/service/common', () => ({ + uploadRemoteFileInfo: vi.fn(), +})) + +describe('useMCPModalForm', () => { + describe('Utility Functions', () => { + describe('isValidUrl', () => { + it('should return true for valid http URL', () => { + expect(isValidUrl('http://example.com')).toBe(true) + }) + + it('should return true for valid https URL', () => { + expect(isValidUrl('https://example.com')).toBe(true) + }) + + it('should return true for URL with path', () => { + expect(isValidUrl('https://example.com/path/to/resource')).toBe(true) + }) + + it('should return true for URL with query params', () => { + expect(isValidUrl('https://example.com?foo=bar')).toBe(true) + }) + + it('should return false for invalid URL', () => { + expect(isValidUrl('not-a-url')).toBe(false) + }) + + it('should return false for ftp URL', () => { + expect(isValidUrl('ftp://example.com')).toBe(false) + }) + + it('should return false for empty string', () => { + expect(isValidUrl('')).toBe(false) + }) + + it('should return false for file URL', () => { + expect(isValidUrl('file:///path/to/file')).toBe(false) + }) + }) + + describe('isValidServerID', () => { + it('should return true for lowercase letters', () => { + expect(isValidServerID('myserver')).toBe(true) + }) + + it('should return true for numbers', () => { + expect(isValidServerID('123')).toBe(true) + }) + + it('should return true for alphanumeric with hyphens', () => { + expect(isValidServerID('my-server-123')).toBe(true) + }) + + it('should return true for alphanumeric with underscores', () => { + expect(isValidServerID('my_server_123')).toBe(true) + }) + + it('should return true for max length (24 chars)', () => { + expect(isValidServerID('abcdefghijklmnopqrstuvwx')).toBe(true) + }) + + it('should return false for uppercase letters', () => { + expect(isValidServerID('MyServer')).toBe(false) + }) + + it('should return false for spaces', () => { + expect(isValidServerID('my server')).toBe(false) + }) + + it('should return false for special characters', () => { + expect(isValidServerID('my@server')).toBe(false) + }) + + it('should return false for empty string', () => { + expect(isValidServerID('')).toBe(false) + }) + + it('should return false for string longer than 24 chars', () => { + expect(isValidServerID('abcdefghijklmnopqrstuvwxy')).toBe(false) + }) + }) + }) + + describe('Hook Initialization', () => { + describe('Create Mode (no data)', () => { + it('should initialize with default values', () => { + const { result } = renderHook(() => useMCPModalForm()) + + expect(result.current.isCreate).toBe(true) + expect(result.current.formKey).toBe('create') + expect(result.current.state.url).toBe('') + expect(result.current.state.name).toBe('') + expect(result.current.state.serverIdentifier).toBe('') + expect(result.current.state.timeout).toBe(30) + expect(result.current.state.sseReadTimeout).toBe(300) + expect(result.current.state.headers).toEqual([]) + expect(result.current.state.authMethod).toBe(MCPAuthMethod.authentication) + expect(result.current.state.isDynamicRegistration).toBe(true) + expect(result.current.state.clientID).toBe('') + expect(result.current.state.credentials).toBe('') + }) + + it('should initialize with default emoji icon', () => { + const { result } = renderHook(() => useMCPModalForm()) + + expect(result.current.state.appIcon).toEqual({ + type: 'emoji', + icon: '🔗', + background: '#6366F1', + }) + }) + }) + + describe('Edit Mode (with data)', () => { + const mockData: ToolWithProvider = { + id: 'test-id-123', + name: 'Test MCP Server', + server_url: 'https://example.com/mcp', + server_identifier: 'test-server', + icon: { content: '🚀', background: '#FF0000' }, + configuration: { + timeout: 60, + sse_read_timeout: 600, + }, + masked_headers: { + 'Authorization': '***', + 'X-Custom': 'value', + }, + is_dynamic_registration: false, + authentication: { + client_id: 'client-123', + client_secret: 'secret-456', + }, + } as unknown as ToolWithProvider + + it('should initialize with data values', () => { + const { result } = renderHook(() => useMCPModalForm(mockData)) + + expect(result.current.isCreate).toBe(false) + expect(result.current.formKey).toBe('test-id-123') + expect(result.current.state.url).toBe('https://example.com/mcp') + expect(result.current.state.name).toBe('Test MCP Server') + expect(result.current.state.serverIdentifier).toBe('test-server') + expect(result.current.state.timeout).toBe(60) + expect(result.current.state.sseReadTimeout).toBe(600) + expect(result.current.state.isDynamicRegistration).toBe(false) + expect(result.current.state.clientID).toBe('client-123') + expect(result.current.state.credentials).toBe('secret-456') + }) + + it('should initialize headers from masked_headers', () => { + const { result } = renderHook(() => useMCPModalForm(mockData)) + + expect(result.current.state.headers).toHaveLength(2) + expect(result.current.state.headers[0].key).toBe('Authorization') + expect(result.current.state.headers[0].value).toBe('***') + expect(result.current.state.headers[1].key).toBe('X-Custom') + expect(result.current.state.headers[1].value).toBe('value') + }) + + it('should initialize emoji icon from data', () => { + const { result } = renderHook(() => useMCPModalForm(mockData)) + + expect(result.current.state.appIcon.type).toBe('emoji') + expect(((result.current.state.appIcon) as AppIconEmojiSelection).icon).toBe('🚀') + expect(((result.current.state.appIcon) as AppIconEmojiSelection).background).toBe('#FF0000') + }) + + it('should store original server URL and ID', () => { + const { result } = renderHook(() => useMCPModalForm(mockData)) + + expect(result.current.originalServerUrl).toBe('https://example.com/mcp') + expect(result.current.originalServerID).toBe('test-server') + }) + }) + + describe('Edit Mode with string icon', () => { + const mockDataWithImageIcon: ToolWithProvider = { + id: 'test-id', + name: 'Test', + icon: 'https://example.com/files/abc123/file-preview/icon.png', + } as unknown as ToolWithProvider + + it('should initialize image icon from string URL', () => { + const { result } = renderHook(() => useMCPModalForm(mockDataWithImageIcon)) + + expect(result.current.state.appIcon.type).toBe('image') + expect(((result.current.state.appIcon) as AppIconImageSelection).url).toBe('https://example.com/files/abc123/file-preview/icon.png') + expect(((result.current.state.appIcon) as AppIconImageSelection).fileId).toBe('abc123') + }) + }) + }) + + describe('Actions', () => { + it('should update url', () => { + const { result } = renderHook(() => useMCPModalForm()) + + act(() => { + result.current.actions.setUrl('https://new-url.com') + }) + + expect(result.current.state.url).toBe('https://new-url.com') + }) + + it('should update name', () => { + const { result } = renderHook(() => useMCPModalForm()) + + act(() => { + result.current.actions.setName('New Server Name') + }) + + expect(result.current.state.name).toBe('New Server Name') + }) + + it('should update serverIdentifier', () => { + const { result } = renderHook(() => useMCPModalForm()) + + act(() => { + result.current.actions.setServerIdentifier('new-server-id') + }) + + expect(result.current.state.serverIdentifier).toBe('new-server-id') + }) + + it('should update timeout', () => { + const { result } = renderHook(() => useMCPModalForm()) + + act(() => { + result.current.actions.setTimeout(120) + }) + + expect(result.current.state.timeout).toBe(120) + }) + + it('should update sseReadTimeout', () => { + const { result } = renderHook(() => useMCPModalForm()) + + act(() => { + result.current.actions.setSseReadTimeout(900) + }) + + expect(result.current.state.sseReadTimeout).toBe(900) + }) + + it('should update headers', () => { + const { result } = renderHook(() => useMCPModalForm()) + const newHeaders = [{ id: '1', key: 'X-New', value: 'new-value' }] + + act(() => { + result.current.actions.setHeaders(newHeaders) + }) + + expect(result.current.state.headers).toEqual(newHeaders) + }) + + it('should update authMethod', () => { + const { result } = renderHook(() => useMCPModalForm()) + + act(() => { + result.current.actions.setAuthMethod(MCPAuthMethod.headers) + }) + + expect(result.current.state.authMethod).toBe(MCPAuthMethod.headers) + }) + + it('should update isDynamicRegistration', () => { + const { result } = renderHook(() => useMCPModalForm()) + + act(() => { + result.current.actions.setIsDynamicRegistration(false) + }) + + expect(result.current.state.isDynamicRegistration).toBe(false) + }) + + it('should update clientID', () => { + const { result } = renderHook(() => useMCPModalForm()) + + act(() => { + result.current.actions.setClientID('new-client-id') + }) + + expect(result.current.state.clientID).toBe('new-client-id') + }) + + it('should update credentials', () => { + const { result } = renderHook(() => useMCPModalForm()) + + act(() => { + result.current.actions.setCredentials('new-secret') + }) + + expect(result.current.state.credentials).toBe('new-secret') + }) + + it('should update appIcon', () => { + const { result } = renderHook(() => useMCPModalForm()) + const newIcon = { type: 'emoji' as const, icon: '🎉', background: '#00FF00' } + + act(() => { + result.current.actions.setAppIcon(newIcon) + }) + + expect(result.current.state.appIcon).toEqual(newIcon) + }) + + it('should toggle showAppIconPicker', () => { + const { result } = renderHook(() => useMCPModalForm()) + + expect(result.current.state.showAppIconPicker).toBe(false) + + act(() => { + result.current.actions.setShowAppIconPicker(true) + }) + + expect(result.current.state.showAppIconPicker).toBe(true) + }) + + it('should reset icon to default', () => { + const { result } = renderHook(() => useMCPModalForm()) + + // Change icon first + act(() => { + result.current.actions.setAppIcon({ type: 'emoji', icon: '🎉', background: '#00FF00' }) + }) + + expect(((result.current.state.appIcon) as AppIconEmojiSelection).icon).toBe('🎉') + + // Reset icon + act(() => { + result.current.actions.resetIcon() + }) + + expect(result.current.state.appIcon).toEqual({ + type: 'emoji', + icon: '🔗', + background: '#6366F1', + }) + }) + }) + + describe('handleUrlBlur', () => { + it('should not fetch icon in edit mode (when data is provided)', async () => { + const mockData = { + id: 'test', + name: 'Test', + icon: { content: '🔗', background: '#6366F1' }, + } as unknown as ToolWithProvider + const { result } = renderHook(() => useMCPModalForm(mockData)) + + await act(async () => { + await result.current.actions.handleUrlBlur('https://example.com') + }) + + // In edit mode, handleUrlBlur should return early + expect(result.current.state.isFetchingIcon).toBe(false) + }) + + it('should not fetch icon for invalid URL', async () => { + const { result } = renderHook(() => useMCPModalForm()) + + await act(async () => { + await result.current.actions.handleUrlBlur('not-a-valid-url') + }) + + expect(result.current.state.isFetchingIcon).toBe(false) + }) + + it('should handle error when icon fetch fails with error code', async () => { + const { uploadRemoteFileInfo } = await import('@/service/common') + const mockError = { + json: vi.fn().mockResolvedValue({ code: 'UPLOAD_ERROR' }), + } + vi.mocked(uploadRemoteFileInfo).mockRejectedValueOnce(mockError) + + const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + const { result } = renderHook(() => useMCPModalForm()) + + await act(async () => { + await result.current.actions.handleUrlBlur('https://example.com/mcp') + }) + + // Should have called console.error + expect(consoleErrorSpy).toHaveBeenCalled() + // isFetchingIcon should be reset to false after error + expect(result.current.state.isFetchingIcon).toBe(false) + + consoleErrorSpy.mockRestore() + }) + + it('should handle error when icon fetch fails without error code', async () => { + const { uploadRemoteFileInfo } = await import('@/service/common') + const mockError = { + json: vi.fn().mockResolvedValue({}), + } + vi.mocked(uploadRemoteFileInfo).mockRejectedValueOnce(mockError) + + const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + const { result } = renderHook(() => useMCPModalForm()) + + await act(async () => { + await result.current.actions.handleUrlBlur('https://example.com/mcp') + }) + + // Should have called console.error + expect(consoleErrorSpy).toHaveBeenCalled() + // isFetchingIcon should be reset to false after error + expect(result.current.state.isFetchingIcon).toBe(false) + + consoleErrorSpy.mockRestore() + }) + + it('should fetch icon successfully for valid URL in create mode', async () => { + vi.mocked(await import('@/service/common').then(m => m.uploadRemoteFileInfo)).mockResolvedValueOnce({ + id: 'file123', + name: 'icon.png', + size: 1024, + mime_type: 'image/png', + url: 'https://example.com/files/file123/file-preview/icon.png', + } as unknown as { id: string, name: string, size: number, mime_type: string, url: string }) + + const { result } = renderHook(() => useMCPModalForm()) + + await act(async () => { + await result.current.actions.handleUrlBlur('https://example.com/mcp') + }) + + // Icon should be set to image type + expect(result.current.state.appIcon.type).toBe('image') + expect(((result.current.state.appIcon) as AppIconImageSelection).url).toBe('https://example.com/files/file123/file-preview/icon.png') + expect(result.current.state.isFetchingIcon).toBe(false) + }) + }) + + describe('Edge Cases', () => { + // Base mock data with required icon field + const baseMockData = { + id: 'test', + name: 'Test', + icon: { content: '🔗', background: '#6366F1' }, + } + + it('should handle undefined configuration', () => { + const mockData = { ...baseMockData } as unknown as ToolWithProvider + + const { result } = renderHook(() => useMCPModalForm(mockData)) + + expect(result.current.state.timeout).toBe(30) + expect(result.current.state.sseReadTimeout).toBe(300) + }) + + it('should handle undefined authentication', () => { + const mockData = { ...baseMockData } as unknown as ToolWithProvider + + const { result } = renderHook(() => useMCPModalForm(mockData)) + + expect(result.current.state.clientID).toBe('') + expect(result.current.state.credentials).toBe('') + }) + + it('should handle undefined masked_headers', () => { + const mockData = { ...baseMockData } as unknown as ToolWithProvider + + const { result } = renderHook(() => useMCPModalForm(mockData)) + + expect(result.current.state.headers).toEqual([]) + }) + + it('should handle undefined is_dynamic_registration (defaults to true)', () => { + const mockData = { ...baseMockData } as unknown as ToolWithProvider + + const { result } = renderHook(() => useMCPModalForm(mockData)) + + expect(result.current.state.isDynamicRegistration).toBe(true) + }) + + it('should handle string icon URL', () => { + const mockData = { + id: 'test', + name: 'Test', + icon: 'https://example.com/icon.png', + } as unknown as ToolWithProvider + + const { result } = renderHook(() => useMCPModalForm(mockData)) + + expect(result.current.state.appIcon.type).toBe('image') + expect(((result.current.state.appIcon) as AppIconImageSelection).url).toBe('https://example.com/icon.png') + }) + }) +}) diff --git a/web/app/components/tools/mcp/hooks/use-mcp-modal-form.ts b/web/app/components/tools/mcp/hooks/use-mcp-modal-form.ts new file mode 100644 index 0000000000..286e2bf2e8 --- /dev/null +++ b/web/app/components/tools/mcp/hooks/use-mcp-modal-form.ts @@ -0,0 +1,203 @@ +'use client' +import type { HeaderItem } from '../headers-input' +import type { AppIconSelection } from '@/app/components/base/app-icon-picker' +import type { ToolWithProvider } from '@/app/components/workflow/types' +import { useCallback, useMemo, useRef, useState } from 'react' +import { getDomain } from 'tldts' +import { v4 as uuid } from 'uuid' +import Toast from '@/app/components/base/toast' +import { MCPAuthMethod } from '@/app/components/tools/types' +import { uploadRemoteFileInfo } from '@/service/common' + +const DEFAULT_ICON = { type: 'emoji', icon: '🔗', background: '#6366F1' } + +const extractFileId = (url: string) => { + const match = url.match(/files\/(.+?)\/file-preview/) + return match ? match[1] : null +} + +const getIcon = (data?: ToolWithProvider): AppIconSelection => { + if (!data) + return DEFAULT_ICON as AppIconSelection + if (typeof data.icon === 'string') + return { type: 'image', url: data.icon, fileId: extractFileId(data.icon) } as AppIconSelection + return { + ...data.icon, + icon: data.icon.content, + type: 'emoji', + } as unknown as AppIconSelection +} + +const getInitialHeaders = (data?: ToolWithProvider): HeaderItem[] => { + return Object.entries(data?.masked_headers || {}).map(([key, value]) => ({ id: uuid(), key, value })) +} + +export const isValidUrl = (string: string) => { + try { + const url = new URL(string) + return url.protocol === 'http:' || url.protocol === 'https:' + } + catch { + return false + } +} + +export const isValidServerID = (str: string) => { + return /^[a-z0-9_-]{1,24}$/.test(str) +} + +export type MCPModalFormState = { + url: string + name: string + appIcon: AppIconSelection + showAppIconPicker: boolean + serverIdentifier: string + timeout: number + sseReadTimeout: number + headers: HeaderItem[] + isFetchingIcon: boolean + authMethod: MCPAuthMethod + isDynamicRegistration: boolean + clientID: string + credentials: string +} + +export type MCPModalFormActions = { + setUrl: (url: string) => void + setName: (name: string) => void + setAppIcon: (icon: AppIconSelection) => void + setShowAppIconPicker: (show: boolean) => void + setServerIdentifier: (id: string) => void + setTimeout: (timeout: number) => void + setSseReadTimeout: (timeout: number) => void + setHeaders: (headers: HeaderItem[]) => void + setAuthMethod: (method: string) => void + setIsDynamicRegistration: (value: boolean) => void + setClientID: (id: string) => void + setCredentials: (credentials: string) => void + handleUrlBlur: (url: string) => Promise + resetIcon: () => void +} + +/** + * Custom hook for MCP Modal form state management. + * + * Note: This hook uses a `formKey` (data ID or 'create') to reset form state when + * switching between edit and create modes. All useState initializers read from `data` + * directly, and the key change triggers a remount of the consumer component. + */ +export const useMCPModalForm = (data?: ToolWithProvider) => { + const isCreate = !data + const originalServerUrl = data?.server_url + const originalServerID = data?.server_identifier + + // Form key for resetting state - changes when data changes + const formKey = useMemo(() => data?.id ?? 'create', [data?.id]) + + // Form state - initialized from data + const [url, setUrl] = useState(() => data?.server_url || '') + const [name, setName] = useState(() => data?.name || '') + const [appIcon, setAppIcon] = useState(() => getIcon(data)) + const [showAppIconPicker, setShowAppIconPicker] = useState(false) + const [serverIdentifier, setServerIdentifier] = useState(() => data?.server_identifier || '') + const [timeout, setMcpTimeout] = useState(() => data?.configuration?.timeout || 30) + const [sseReadTimeout, setSseReadTimeout] = useState(() => data?.configuration?.sse_read_timeout || 300) + const [headers, setHeaders] = useState(() => getInitialHeaders(data)) + const [isFetchingIcon, setIsFetchingIcon] = useState(false) + const appIconRef = useRef(null) + + // Auth state + const [authMethod, setAuthMethod] = useState(MCPAuthMethod.authentication) + const [isDynamicRegistration, setIsDynamicRegistration] = useState(() => isCreate ? true : (data?.is_dynamic_registration ?? true)) + const [clientID, setClientID] = useState(() => data?.authentication?.client_id || '') + const [credentials, setCredentials] = useState(() => data?.authentication?.client_secret || '') + + const handleUrlBlur = useCallback(async (urlValue: string) => { + if (data) + return + if (!isValidUrl(urlValue)) + return + const domain = getDomain(urlValue) + const remoteIcon = `https://www.google.com/s2/favicons?domain=${domain}&sz=128` + setIsFetchingIcon(true) + try { + const res = await uploadRemoteFileInfo(remoteIcon, undefined, true) + setAppIcon({ type: 'image', url: res.url, fileId: extractFileId(res.url) || '' }) + } + catch (e) { + let errorMessage = 'Failed to fetch remote icon' + if (e instanceof Response) { + try { + const errorData = await e.json() + if (errorData?.code) + errorMessage = `Upload failed: ${errorData.code}` + } + catch { + // Ignore JSON parsing errors + } + } + else if (e instanceof Error) { + errorMessage = e.message + } + console.error('Failed to fetch remote icon:', e) + Toast.notify({ type: 'warning', message: errorMessage }) + } + finally { + setIsFetchingIcon(false) + } + }, [data]) + + const resetIcon = useCallback(() => { + setAppIcon(getIcon(data)) + }, [data]) + + const handleAuthMethodChange = useCallback((value: string) => { + setAuthMethod(value as MCPAuthMethod) + }, []) + + return { + // Key for form reset (use as React key on parent) + formKey, + + // Metadata + isCreate, + originalServerUrl, + originalServerID, + appIconRef, + + // State + state: { + url, + name, + appIcon, + showAppIconPicker, + serverIdentifier, + timeout, + sseReadTimeout, + headers, + isFetchingIcon, + authMethod, + isDynamicRegistration, + clientID, + credentials, + } satisfies MCPModalFormState, + + // Actions + actions: { + setUrl, + setName, + setAppIcon, + setShowAppIconPicker, + setServerIdentifier, + setTimeout: setMcpTimeout, + setSseReadTimeout, + setHeaders, + setAuthMethod: handleAuthMethodChange, + setIsDynamicRegistration, + setClientID, + setCredentials, + handleUrlBlur, + resetIcon, + } satisfies MCPModalFormActions, + } +} diff --git a/web/app/components/tools/mcp/hooks/use-mcp-service-card.spec.ts b/web/app/components/tools/mcp/hooks/use-mcp-service-card.spec.ts new file mode 100644 index 0000000000..b36f724857 --- /dev/null +++ b/web/app/components/tools/mcp/hooks/use-mcp-service-card.spec.ts @@ -0,0 +1,451 @@ +import type { ReactNode } from 'react' +import type { AppDetailResponse } from '@/models/app' +import type { AppSSO } from '@/types/app' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { act, renderHook } from '@testing-library/react' +import * as React from 'react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { AppModeEnum } from '@/types/app' +import { useMCPServiceCardState } from './use-mcp-service-card' + +// Mutable mock data for MCP server detail +let mockMCPServerDetailData: { + id: string + status: string + server_code: string + description: string + parameters: Record +} | undefined = { + id: 'server-123', + status: 'active', + server_code: 'abc123', + description: 'Test server', + parameters: {}, +} + +// Mock service hooks +vi.mock('@/service/use-tools', () => ({ + useUpdateMCPServer: () => ({ + mutateAsync: vi.fn().mockResolvedValue({}), + }), + useRefreshMCPServerCode: () => ({ + mutateAsync: vi.fn().mockResolvedValue({}), + isPending: false, + }), + useMCPServerDetail: () => ({ + data: mockMCPServerDetailData, + }), + useInvalidateMCPServerDetail: () => vi.fn(), +})) + +// Mock workflow hook +vi.mock('@/service/use-workflow', () => ({ + useAppWorkflow: (appId: string) => ({ + data: appId + ? { + graph: { + nodes: [ + { data: { type: 'start', variables: [{ variable: 'input', label: 'Input' }] } }, + ], + }, + } + : undefined, + }), +})) + +// Mock app context +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceManager: true, + isCurrentWorkspaceEditor: true, + }), +})) + +// Mock apps service +vi.mock('@/service/apps', () => ({ + fetchAppDetail: vi.fn().mockResolvedValue({ + model_config: { + updated_at: '2024-01-01', + user_input_form: [], + }, + }), +})) + +describe('useMCPServiceCardState', () => { + const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + return ({ children }: { children: ReactNode }) => + React.createElement(QueryClientProvider, { client: queryClient }, children) + } + + const createMockAppInfo = (mode: AppModeEnum = AppModeEnum.CHAT): AppDetailResponse & Partial => ({ + id: 'app-123', + name: 'Test App', + mode, + api_base_url: 'https://api.example.com/v1', + } as AppDetailResponse & Partial) + + beforeEach(() => { + // Reset mock data to default (published server) + mockMCPServerDetailData = { + id: 'server-123', + status: 'active', + server_code: 'abc123', + description: 'Test server', + parameters: {}, + } + }) + + describe('Initialization', () => { + it('should initialize with correct default values for basic app', () => { + const appInfo = createMockAppInfo(AppModeEnum.CHAT) + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(result.current.serverPublished).toBe(true) + expect(result.current.serverActivated).toBe(true) + expect(result.current.showConfirmDelete).toBe(false) + expect(result.current.showMCPServerModal).toBe(false) + }) + + it('should initialize with correct values for workflow app', () => { + const appInfo = createMockAppInfo(AppModeEnum.WORKFLOW) + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(result.current.isLoading).toBe(false) + }) + + it('should initialize with correct values for advanced chat app', () => { + const appInfo = createMockAppInfo(AppModeEnum.ADVANCED_CHAT) + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(result.current.isLoading).toBe(false) + }) + }) + + describe('Server URL Generation', () => { + it('should generate correct server URL when published', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(result.current.serverURL).toBe('https://api.example.com/mcp/server/abc123/mcp') + }) + }) + + describe('Permission Flags', () => { + it('should have isCurrentWorkspaceManager as true', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(result.current.isCurrentWorkspaceManager).toBe(true) + }) + + it('should have toggleDisabled false when editor has permissions', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + // Toggle is not disabled when user has permissions and app is published + expect(typeof result.current.toggleDisabled).toBe('boolean') + }) + + it('should have toggleDisabled true when triggerModeDisabled is true', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, true), + { wrapper: createWrapper() }, + ) + + expect(result.current.toggleDisabled).toBe(true) + }) + }) + + describe('UI State Actions', () => { + it('should open confirm delete modal', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(result.current.showConfirmDelete).toBe(false) + + act(() => { + result.current.openConfirmDelete() + }) + + expect(result.current.showConfirmDelete).toBe(true) + }) + + it('should close confirm delete modal', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + act(() => { + result.current.openConfirmDelete() + }) + expect(result.current.showConfirmDelete).toBe(true) + + act(() => { + result.current.closeConfirmDelete() + }) + expect(result.current.showConfirmDelete).toBe(false) + }) + + it('should open server modal', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(result.current.showMCPServerModal).toBe(false) + + act(() => { + result.current.openServerModal() + }) + + expect(result.current.showMCPServerModal).toBe(true) + }) + + it('should handle server modal hide', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + act(() => { + result.current.openServerModal() + }) + expect(result.current.showMCPServerModal).toBe(true) + + let hideResult: { shouldDeactivate: boolean } | undefined + act(() => { + hideResult = result.current.handleServerModalHide(false) + }) + + expect(result.current.showMCPServerModal).toBe(false) + expect(hideResult?.shouldDeactivate).toBe(true) + }) + + it('should not deactivate when wasActivated is true', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + let hideResult: { shouldDeactivate: boolean } | undefined + act(() => { + hideResult = result.current.handleServerModalHide(true) + }) + + expect(hideResult?.shouldDeactivate).toBe(false) + }) + }) + + describe('Handler Functions', () => { + it('should have handleGenCode function', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(typeof result.current.handleGenCode).toBe('function') + }) + + it('should call handleGenCode and invalidate server detail', async () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + await act(async () => { + await result.current.handleGenCode() + }) + + // handleGenCode should complete without error + expect(result.current.genLoading).toBe(false) + }) + + it('should have handleStatusChange function', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(typeof result.current.handleStatusChange).toBe('function') + }) + + it('should have invalidateBasicAppConfig function', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(typeof result.current.invalidateBasicAppConfig).toBe('function') + }) + + it('should call invalidateBasicAppConfig', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + // Call the function - should not throw + act(() => { + result.current.invalidateBasicAppConfig() + }) + + // Function should exist and be callable + expect(typeof result.current.invalidateBasicAppConfig).toBe('function') + }) + }) + + describe('Status Change', () => { + it('should return activated state when status change succeeds', async () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + let statusResult: { activated: boolean } | undefined + await act(async () => { + statusResult = await result.current.handleStatusChange(true) + }) + + expect(statusResult?.activated).toBe(true) + }) + + it('should return deactivated state when disabling', async () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + let statusResult: { activated: boolean } | undefined + await act(async () => { + statusResult = await result.current.handleStatusChange(false) + }) + + expect(statusResult?.activated).toBe(false) + }) + }) + + describe('Unpublished Server', () => { + it('should open modal and return not activated when enabling unpublished server', async () => { + // Set mock to return undefined (unpublished server) + mockMCPServerDetailData = undefined + + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + // Verify server is not published + expect(result.current.serverPublished).toBe(false) + + let statusResult: { activated: boolean } | undefined + await act(async () => { + statusResult = await result.current.handleStatusChange(true) + }) + + // Should open modal and return not activated + expect(result.current.showMCPServerModal).toBe(true) + expect(statusResult?.activated).toBe(false) + }) + }) + + describe('Loading States', () => { + it('should have genLoading state', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(typeof result.current.genLoading).toBe('boolean') + }) + + it('should have isLoading state for basic app', () => { + const appInfo = createMockAppInfo(AppModeEnum.CHAT) + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + // Basic app doesn't need workflow, so isLoading should be false + expect(result.current.isLoading).toBe(false) + }) + }) + + describe('Detail Data', () => { + it('should return detail data when available', () => { + const appInfo = createMockAppInfo() + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(result.current.detail).toBeDefined() + expect(result.current.detail?.id).toBe('server-123') + expect(result.current.detail?.status).toBe('active') + }) + }) + + describe('Latest Params', () => { + it('should return latestParams for workflow app', () => { + const appInfo = createMockAppInfo(AppModeEnum.WORKFLOW) + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(Array.isArray(result.current.latestParams)).toBe(true) + }) + + it('should return latestParams for basic app', () => { + const appInfo = createMockAppInfo(AppModeEnum.CHAT) + const { result } = renderHook( + () => useMCPServiceCardState(appInfo, false), + { wrapper: createWrapper() }, + ) + + expect(Array.isArray(result.current.latestParams)).toBe(true) + }) + }) +}) diff --git a/web/app/components/tools/mcp/hooks/use-mcp-service-card.ts b/web/app/components/tools/mcp/hooks/use-mcp-service-card.ts new file mode 100644 index 0000000000..dfb1c75a2a --- /dev/null +++ b/web/app/components/tools/mcp/hooks/use-mcp-service-card.ts @@ -0,0 +1,179 @@ +'use client' +import type { AppDetailResponse } from '@/models/app' +import type { AppSSO } from '@/types/app' +import { useQuery, useQueryClient } from '@tanstack/react-query' +import { useCallback, useMemo, useState } from 'react' +import { BlockEnum } from '@/app/components/workflow/types' +import { useAppContext } from '@/context/app-context' +import { fetchAppDetail } from '@/service/apps' +import { + useInvalidateMCPServerDetail, + useMCPServerDetail, + useRefreshMCPServerCode, + useUpdateMCPServer, +} from '@/service/use-tools' +import { useAppWorkflow } from '@/service/use-workflow' +import { AppModeEnum } from '@/types/app' + +const BASIC_APP_CONFIG_KEY = 'basicAppConfig' + +type AppInfo = AppDetailResponse & Partial + +type BasicAppConfig = { + updated_at?: string + user_input_form?: Array> +} + +export const useMCPServiceCardState = ( + appInfo: AppInfo, + triggerModeDisabled: boolean, +) => { + const appId = appInfo.id + const queryClient = useQueryClient() + + // API hooks + const { mutateAsync: updateMCPServer } = useUpdateMCPServer() + const { mutateAsync: refreshMCPServerCode, isPending: genLoading } = useRefreshMCPServerCode() + const invalidateMCPServerDetail = useInvalidateMCPServerDetail() + + // Context + const { isCurrentWorkspaceManager, isCurrentWorkspaceEditor } = useAppContext() + + // UI state + const [showConfirmDelete, setShowConfirmDelete] = useState(false) + const [showMCPServerModal, setShowMCPServerModal] = useState(false) + + // Derived app type values + const isAdvancedApp = appInfo?.mode === AppModeEnum.ADVANCED_CHAT || appInfo?.mode === AppModeEnum.WORKFLOW + const isBasicApp = !isAdvancedApp + const isWorkflowApp = appInfo.mode === AppModeEnum.WORKFLOW + + // Workflow data for advanced apps + const { data: currentWorkflow } = useAppWorkflow(isAdvancedApp ? appId : '') + + // Basic app config fetch using React Query + const { data: basicAppConfig = {} } = useQuery({ + queryKey: [BASIC_APP_CONFIG_KEY, appId], + queryFn: async () => { + const res = await fetchAppDetail({ url: '/apps', id: appId }) + return (res?.model_config as BasicAppConfig) || {} + }, + enabled: isBasicApp && !!appId, + }) + + // MCP server detail + const { data: detail } = useMCPServerDetail(appId) + const { id, status, server_code } = detail ?? {} + + // Server state + const serverPublished = !!id + const serverActivated = status === 'active' + const serverURL = serverPublished + ? `${appInfo.api_base_url.replace('/v1', '')}/mcp/server/${server_code}/mcp` + : '***********' + + // App state checks + const appUnpublished = isAdvancedApp ? !currentWorkflow?.graph : !basicAppConfig.updated_at + const hasStartNode = currentWorkflow?.graph?.nodes?.some(node => node.data.type === BlockEnum.Start) + const missingStartNode = isWorkflowApp && !hasStartNode + const hasInsufficientPermissions = !isCurrentWorkspaceEditor + const toggleDisabled = hasInsufficientPermissions || appUnpublished || missingStartNode || triggerModeDisabled + const isMinimalState = appUnpublished || missingStartNode + + // Basic app input form + const basicAppInputForm = useMemo(() => { + if (!isBasicApp || !basicAppConfig?.user_input_form) + return [] + return (basicAppConfig.user_input_form as Array>).map((item) => { + const type = Object.keys(item)[0] + return { + ...(item[type] as object), + type: type || 'text-input', + } + }) + }, [basicAppConfig?.user_input_form, isBasicApp]) + + // Latest params for modal + const latestParams = useMemo(() => { + if (isAdvancedApp) { + if (!currentWorkflow?.graph) + return [] + type StartNodeData = { type: string, variables?: Array<{ variable: string, label: string }> } + const startNode = currentWorkflow?.graph.nodes.find(node => node.data.type === BlockEnum.Start) as { data: StartNodeData } | undefined + return startNode?.data.variables || [] + } + return basicAppInputForm + }, [currentWorkflow, basicAppInputForm, isAdvancedApp]) + + // Handlers + const handleGenCode = useCallback(async () => { + await refreshMCPServerCode(detail?.id || '') + invalidateMCPServerDetail(appId) + }, [refreshMCPServerCode, detail?.id, invalidateMCPServerDetail, appId]) + + const handleStatusChange = useCallback(async (state: boolean) => { + if (state && !serverPublished) { + setShowMCPServerModal(true) + return { activated: false } + } + + await updateMCPServer({ + appID: appId, + id: id || '', + description: detail?.description || '', + parameters: detail?.parameters || {}, + status: state ? 'active' : 'inactive', + }) + invalidateMCPServerDetail(appId) + return { activated: state } + }, [serverPublished, updateMCPServer, appId, id, detail, invalidateMCPServerDetail]) + + const handleServerModalHide = useCallback((wasActivated: boolean) => { + setShowMCPServerModal(false) + // If server wasn't activated before opening modal, keep it deactivated + return { shouldDeactivate: !wasActivated } + }, []) + + const openConfirmDelete = useCallback(() => setShowConfirmDelete(true), []) + const closeConfirmDelete = useCallback(() => setShowConfirmDelete(false), []) + const openServerModal = useCallback(() => setShowMCPServerModal(true), []) + + const invalidateBasicAppConfig = useCallback(() => { + queryClient.invalidateQueries({ queryKey: [BASIC_APP_CONFIG_KEY, appId] }) + }, [queryClient, appId]) + + return { + // Loading states + genLoading, + isLoading: isAdvancedApp ? !currentWorkflow : false, + + // Server state + serverPublished, + serverActivated, + serverURL, + detail, + + // Permission & validation flags + isCurrentWorkspaceManager, + toggleDisabled, + isMinimalState, + appUnpublished, + missingStartNode, + + // UI state + showConfirmDelete, + showMCPServerModal, + + // Data + latestParams, + + // Handlers + handleGenCode, + handleStatusChange, + handleServerModalHide, + openConfirmDelete, + closeConfirmDelete, + openServerModal, + invalidateBasicAppConfig, + } +} diff --git a/web/app/components/tools/mcp/mcp-server-modal.spec.tsx b/web/app/components/tools/mcp/mcp-server-modal.spec.tsx new file mode 100644 index 0000000000..62eabd0690 --- /dev/null +++ b/web/app/components/tools/mcp/mcp-server-modal.spec.tsx @@ -0,0 +1,361 @@ +import type { ReactNode } from 'react' +import type { MCPServerDetail } from '@/app/components/tools/types' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import * as React from 'react' +import { describe, expect, it, vi } from 'vitest' +import MCPServerModal from './mcp-server-modal' + +// Mock the services +vi.mock('@/service/use-tools', () => ({ + useCreateMCPServer: () => ({ + mutateAsync: vi.fn().mockResolvedValue({ result: 'success' }), + isPending: false, + }), + useUpdateMCPServer: () => ({ + mutateAsync: vi.fn().mockResolvedValue({ result: 'success' }), + isPending: false, + }), + useInvalidateMCPServerDetail: () => vi.fn(), +})) + +describe('MCPServerModal', () => { + const createWrapper = () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + return ({ children }: { children: ReactNode }) => + React.createElement(QueryClientProvider, { client: queryClient }, children) + } + + const defaultProps = { + appID: 'app-123', + show: true, + onHide: vi.fn(), + } + + describe('Rendering', () => { + it('should render without crashing', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.server.modal.addTitle')).toBeInTheDocument() + }) + + it('should render add title when no data is provided', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.server.modal.addTitle')).toBeInTheDocument() + }) + + it('should render edit title when data is provided', () => { + const mockData = { + id: 'server-1', + description: 'Existing description', + parameters: {}, + } as unknown as MCPServerDetail + + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.server.modal.editTitle')).toBeInTheDocument() + }) + + it('should render description label', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.server.modal.description')).toBeInTheDocument() + }) + + it('should render required indicator', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('*')).toBeInTheDocument() + }) + + it('should render description textarea', () => { + render(, { wrapper: createWrapper() }) + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.descriptionPlaceholder') + expect(textarea).toBeInTheDocument() + }) + + it('should render cancel button', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.modal.cancel')).toBeInTheDocument() + }) + + it('should render confirm button in add mode', () => { + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.server.modal.confirm')).toBeInTheDocument() + }) + + it('should render save button in edit mode', () => { + const mockData = { + id: 'server-1', + description: 'Existing description', + parameters: {}, + } as unknown as MCPServerDetail + + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.modal.save')).toBeInTheDocument() + }) + + it('should render close icon', () => { + render(, { wrapper: createWrapper() }) + const closeButton = document.querySelector('.cursor-pointer svg') + expect(closeButton).toBeInTheDocument() + }) + }) + + describe('Parameters Section', () => { + it('should not render parameters section when no latestParams', () => { + render(, { wrapper: createWrapper() }) + expect(screen.queryByText('tools.mcp.server.modal.parameters')).not.toBeInTheDocument() + }) + + it('should render parameters section when latestParams is provided', () => { + const latestParams = [ + { variable: 'param1', label: 'Parameter 1', type: 'string' }, + ] + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.server.modal.parameters')).toBeInTheDocument() + }) + + it('should render parameters tip', () => { + const latestParams = [ + { variable: 'param1', label: 'Parameter 1', type: 'string' }, + ] + render(, { wrapper: createWrapper() }) + expect(screen.getByText('tools.mcp.server.modal.parametersTip')).toBeInTheDocument() + }) + + it('should render parameter items', () => { + const latestParams = [ + { variable: 'param1', label: 'Parameter 1', type: 'string' }, + { variable: 'param2', label: 'Parameter 2', type: 'number' }, + ] + render(, { wrapper: createWrapper() }) + expect(screen.getByText('Parameter 1')).toBeInTheDocument() + expect(screen.getByText('Parameter 2')).toBeInTheDocument() + }) + }) + + describe('Form Interactions', () => { + it('should update description when typing', () => { + render(, { wrapper: createWrapper() }) + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.descriptionPlaceholder') + fireEvent.change(textarea, { target: { value: 'New description' } }) + + expect(textarea).toHaveValue('New description') + }) + + it('should call onHide when cancel button is clicked', () => { + const onHide = vi.fn() + render(, { wrapper: createWrapper() }) + + const cancelButton = screen.getByText('tools.mcp.modal.cancel') + fireEvent.click(cancelButton) + + expect(onHide).toHaveBeenCalledTimes(1) + }) + + it('should call onHide when close icon is clicked', () => { + const onHide = vi.fn() + render(, { wrapper: createWrapper() }) + + const closeButton = document.querySelector('.cursor-pointer') + if (closeButton) { + fireEvent.click(closeButton) + expect(onHide).toHaveBeenCalled() + } + }) + + it('should disable confirm button when description is empty', () => { + render(, { wrapper: createWrapper() }) + + const confirmButton = screen.getByText('tools.mcp.server.modal.confirm') + expect(confirmButton).toBeDisabled() + }) + + it('should enable confirm button when description is filled', () => { + render(, { wrapper: createWrapper() }) + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.descriptionPlaceholder') + fireEvent.change(textarea, { target: { value: 'Valid description' } }) + + const confirmButton = screen.getByText('tools.mcp.server.modal.confirm') + expect(confirmButton).not.toBeDisabled() + }) + }) + + describe('Edit Mode', () => { + const mockData = { + id: 'server-1', + description: 'Existing description', + parameters: { param1: 'existing value' }, + } as unknown as MCPServerDetail + + it('should populate description with existing value', () => { + render(, { wrapper: createWrapper() }) + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.descriptionPlaceholder') + expect(textarea).toHaveValue('Existing description') + }) + + it('should populate parameters with existing values', () => { + const latestParams = [ + { variable: 'param1', label: 'Parameter 1', type: 'string' }, + ] + render( + , + { wrapper: createWrapper() }, + ) + + const paramInput = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + expect(paramInput).toHaveValue('existing value') + }) + }) + + describe('Form Submission', () => { + it('should submit form with description', async () => { + const onHide = vi.fn() + render(, { wrapper: createWrapper() }) + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.descriptionPlaceholder') + fireEvent.change(textarea, { target: { value: 'Test description' } }) + + const confirmButton = screen.getByText('tools.mcp.server.modal.confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(onHide).toHaveBeenCalled() + }) + }) + }) + + describe('With App Info', () => { + it('should use appInfo description as default when no data', () => { + const appInfo = { description: 'App default description' } + render(, { wrapper: createWrapper() }) + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.descriptionPlaceholder') + expect(textarea).toHaveValue('App default description') + }) + + it('should prefer data description over appInfo description', () => { + const appInfo = { description: 'App default description' } + const mockData = { + id: 'server-1', + description: 'Data description', + parameters: {}, + } as unknown as MCPServerDetail + + render( + , + { wrapper: createWrapper() }, + ) + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.descriptionPlaceholder') + expect(textarea).toHaveValue('Data description') + }) + }) + + describe('Not Shown State', () => { + it('should not render modal content when show is false', () => { + render(, { wrapper: createWrapper() }) + expect(screen.queryByText('tools.mcp.server.modal.addTitle')).not.toBeInTheDocument() + }) + }) + + describe('Update Mode Submission', () => { + it('should submit update when data is provided', async () => { + const onHide = vi.fn() + const mockData = { + id: 'server-1', + description: 'Existing description', + parameters: { param1: 'value1' }, + } as unknown as MCPServerDetail + + render( + , + { wrapper: createWrapper() }, + ) + + // Change description + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.descriptionPlaceholder') + fireEvent.change(textarea, { target: { value: 'Updated description' } }) + + // Click save button + const saveButton = screen.getByText('tools.mcp.modal.save') + fireEvent.click(saveButton) + + await waitFor(() => { + expect(onHide).toHaveBeenCalled() + }) + }) + }) + + describe('Parameter Handling', () => { + it('should update parameter value when changed', async () => { + const latestParams = [ + { variable: 'param1', label: 'Parameter 1', type: 'string' }, + { variable: 'param2', label: 'Parameter 2', type: 'string' }, + ] + + render( + , + { wrapper: createWrapper() }, + ) + + // Fill description first + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.descriptionPlaceholder') + fireEvent.change(textarea, { target: { value: 'Test description' } }) + + // Get all parameter inputs + const paramInputs = screen.getAllByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + + // Change the first parameter value + fireEvent.change(paramInputs[0], { target: { value: 'new param value' } }) + + expect(paramInputs[0]).toHaveValue('new param value') + }) + + it('should submit with parameter values', async () => { + const onHide = vi.fn() + const latestParams = [ + { variable: 'param1', label: 'Parameter 1', type: 'string' }, + ] + + render( + , + { wrapper: createWrapper() }, + ) + + // Fill description + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.descriptionPlaceholder') + fireEvent.change(textarea, { target: { value: 'Test description' } }) + + // Fill parameter + const paramInput = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + fireEvent.change(paramInput, { target: { value: 'param value' } }) + + // Submit + const confirmButton = screen.getByText('tools.mcp.server.modal.confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(onHide).toHaveBeenCalled() + }) + }) + + it('should handle empty description submission', async () => { + const onHide = vi.fn() + render(, { wrapper: createWrapper() }) + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.descriptionPlaceholder') + fireEvent.change(textarea, { target: { value: '' } }) + + // Button should be disabled + const confirmButton = screen.getByText('tools.mcp.server.modal.confirm') + expect(confirmButton).toBeDisabled() + }) + }) +}) diff --git a/web/app/components/tools/mcp/mcp-server-param-item.spec.tsx b/web/app/components/tools/mcp/mcp-server-param-item.spec.tsx new file mode 100644 index 0000000000..6e3a48e330 --- /dev/null +++ b/web/app/components/tools/mcp/mcp-server-param-item.spec.tsx @@ -0,0 +1,165 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { describe, expect, it, vi } from 'vitest' +import MCPServerParamItem from './mcp-server-param-item' + +describe('MCPServerParamItem', () => { + const defaultProps = { + data: { + label: 'Test Label', + variable: 'test_variable', + type: 'string', + }, + value: '', + onChange: vi.fn(), + } + + describe('Rendering', () => { + it('should render without crashing', () => { + render() + expect(screen.getByText('Test Label')).toBeInTheDocument() + }) + + it('should display label', () => { + render() + expect(screen.getByText('Test Label')).toBeInTheDocument() + }) + + it('should display variable name', () => { + render() + expect(screen.getByText('test_variable')).toBeInTheDocument() + }) + + it('should display type', () => { + render() + expect(screen.getByText('string')).toBeInTheDocument() + }) + + it('should display separator dot', () => { + render() + expect(screen.getByText('·')).toBeInTheDocument() + }) + + it('should render textarea with placeholder', () => { + render() + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + expect(textarea).toBeInTheDocument() + }) + }) + + describe('Value Display', () => { + it('should display empty value by default', () => { + render() + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + expect(textarea).toHaveValue('') + }) + + it('should display provided value', () => { + render() + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + expect(textarea).toHaveValue('test value') + }) + + it('should display long text value', () => { + const longValue = 'This is a very long text value that might span multiple lines' + render() + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + expect(textarea).toHaveValue(longValue) + }) + }) + + describe('User Interactions', () => { + it('should call onChange when text is entered', () => { + const onChange = vi.fn() + render() + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + fireEvent.change(textarea, { target: { value: 'new value' } }) + + expect(onChange).toHaveBeenCalledWith('new value') + }) + + it('should call onChange with empty string when cleared', () => { + const onChange = vi.fn() + render() + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + fireEvent.change(textarea, { target: { value: '' } }) + + expect(onChange).toHaveBeenCalledWith('') + }) + + it('should handle multiple changes', () => { + const onChange = vi.fn() + render() + + const textarea = screen.getByPlaceholderText('tools.mcp.server.modal.parametersPlaceholder') + + fireEvent.change(textarea, { target: { value: 'first' } }) + fireEvent.change(textarea, { target: { value: 'second' } }) + fireEvent.change(textarea, { target: { value: 'third' } }) + + expect(onChange).toHaveBeenCalledTimes(3) + expect(onChange).toHaveBeenLastCalledWith('third') + }) + }) + + describe('Different Data Types', () => { + it('should display number type', () => { + const props = { + ...defaultProps, + data: { label: 'Count', variable: 'count', type: 'number' }, + } + render() + expect(screen.getByText('number')).toBeInTheDocument() + }) + + it('should display boolean type', () => { + const props = { + ...defaultProps, + data: { label: 'Enabled', variable: 'enabled', type: 'boolean' }, + } + render() + expect(screen.getByText('boolean')).toBeInTheDocument() + }) + + it('should display array type', () => { + const props = { + ...defaultProps, + data: { label: 'Items', variable: 'items', type: 'array' }, + } + render() + expect(screen.getByText('array')).toBeInTheDocument() + }) + }) + + describe('Edge Cases', () => { + it('should handle special characters in label', () => { + const props = { + ...defaultProps, + data: { label: 'Test