diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index fdc05d1d65..cbd6edf94b 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -47,13 +47,9 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: uv run --directory api --dev lint-imports - - name: Run Basedpyright Checks + - name: Run Type Checks if: steps.changed-files.outputs.any_changed == 'true' - run: dev/basedpyright-check - - - name: Run Mypy Type Checks - if: steps.changed-files.outputs.any_changed == 'true' - run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . + run: make type-check - name: Dotenv check if: steps.changed-files.outputs.any_changed == 'true' diff --git a/Makefile b/Makefile index e92a7b1314..20cede9a5e 100644 --- a/Makefile +++ b/Makefile @@ -68,9 +68,11 @@ lint: @echo "โœ… Linting complete" type-check: - @echo "๐Ÿ“ Running type check with basedpyright..." - @uv run --directory api --dev basedpyright - @echo "โœ… Type check complete" + @echo "๐Ÿ“ Running type checks (basedpyright + mypy + ty)..." + @./dev/basedpyright-check $(PATH_TO_CHECK) + @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . + @cd api && uv run ty check + @echo "โœ… Type checks complete" test: @echo "๐Ÿงช Running backend unit tests..." @@ -130,7 +132,7 @@ help: @echo " make format - Format code with ruff" @echo " make check - Check code with ruff" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" - @echo " make type-check - Run type checking with basedpyright" + @echo " make type-check - Run type checks (basedpyright, mypy, ty)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @echo "" @echo "Docker Build Targets:" diff --git a/api/.env.example b/api/.env.example index c3b1474549..8bd2c706c1 100644 --- a/api/.env.example +++ b/api/.env.example @@ -617,6 +617,7 @@ PLUGIN_DAEMON_URL=http://127.0.0.1:5002 PLUGIN_REMOTE_INSTALL_PORT=5003 PLUGIN_REMOTE_INSTALL_HOST=localhost PLUGIN_MAX_PACKAGE_SIZE=15728640 +PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600 INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 # Marketplace configuration @@ -716,4 +717,3 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000 - diff --git a/api/.importlinter b/api/.importlinter index 2b4a3a5bd6..9dad254560 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -227,6 +227,9 @@ ignore_imports = core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset + core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service + core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods core.workflow.nodes.llm.node -> models.dataset core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer @@ -300,6 +303,58 @@ ignore_imports = core.workflow.nodes.agent.agent_node -> services core.workflow.nodes.tool.tool_node -> services +[importlinter:contract:model-runtime-no-internal-imports] +name = Model Runtime Internal Imports +type = forbidden +source_modules = + core.model_runtime +forbidden_modules = + configs + controllers + extensions + models + services + tasks + core.agent + core.app + core.base + core.callback_handler + core.datasource + core.db + core.entities + core.errors + core.extension + core.external_data_tool + core.file + core.helper + core.hosting_configuration + core.indexing_runner + core.llm_generator + core.logging + core.mcp + core.memory + core.model_manager + core.moderation + core.ops + core.plugin + core.prompt + core.provider_manager + core.rag + core.repositories + core.schemas + core.tools + core.trigger + core.variables + core.workflow +ignore_imports = + core.model_runtime.model_providers.__base.ai_model -> configs + core.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis + core.model_runtime.model_providers.__base.large_language_model -> configs + core.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type + core.model_runtime.model_providers.model_provider_factory -> configs + core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis + core.model_runtime.model_providers.model_provider_factory -> models.provider_ids + [importlinter:contract:rsc] name = RSC type = layers diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 786094f295..d97e9a0440 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -243,6 +243,11 @@ class PluginConfig(BaseSettings): default=15728640 * 12, ) + PLUGIN_MODEL_SCHEMA_CACHE_TTL: PositiveInt = Field( + description="TTL in seconds for caching plugin model schemas in Redis", + default=60 * 60, + ) + class MarketplaceConfig(BaseSettings): """ diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 7c16bc231f..c52dcf8a57 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar if TYPE_CHECKING: from core.datasource.__base.datasource_provider import DatasourcePluginProviderController - from core.model_runtime.entities.model_entities import AIModelEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.trigger.provider import PluginTriggerProviderController @@ -29,12 +28,6 @@ plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( ContextVar("plugin_model_providers_lock") ) -plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock")) - -plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( - ContextVar("plugin_model_schemas") -) - datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( RecyclableContextVar(ContextVar("datasource_plugin_providers")) ) diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index e1ee2c24b8..03b602f6e8 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -243,15 +243,13 @@ class InsertExploreBannerApi(Resource): def post(self): payload = InsertExploreBannerPayload.model_validate(console_ns.payload) - content = { - "category": payload.category, - "title": payload.title, - "description": payload.description, - "img-src": payload.img_src, - } - banner = ExporleBanner( - content=content, + content={ + "category": payload.category, + "title": payload.title, + "description": payload.description, + "img-src": payload.img_src, + }, link=payload.link, sort=payload.sort, language=payload.language, diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 8fbbc51e21..30e4ed1119 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -148,6 +148,7 @@ class DatasetUpdatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None retrieval_model: dict[str, Any] | None = None + summary_index_setting: dict[str, Any] | None = None partial_member_list: list[dict[str, str]] | None = None external_retrieval_model: dict[str, Any] | None = None external_knowledge_id: str | None = None @@ -288,7 +289,14 @@ class DatasetListApi(Resource): @enterprise_license_required def get(self): current_user, current_tenant_id = current_account_with_tenant() - query = ConsoleDatasetListQuery.model_validate(request.args.to_dict()) + # Convert query parameters to dict, handling list parameters correctly + query_params: dict[str, str | list[str]] = dict(request.args.to_dict()) + # Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value) + if "ids" in request.args: + query_params["ids"] = request.args.getlist("ids") + if "tag_ids" in request.args: + query_params["tag_ids"] = request.args.getlist("tag_ids") + query = ConsoleDatasetListQuery.model_validate(query_params) # provider = request.args.get("provider", default="vendor") if query.ids: datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 57fb9abf29..6e3c0db8a3 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -45,6 +45,7 @@ from models.dataset import DocumentPipelineExecutionLog from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.file_service import FileService +from tasks.generate_summary_index_task import generate_summary_index_task from ..app.error import ( ProviderModelCurrentlyNotSupportError, @@ -103,6 +104,10 @@ class DocumentRenamePayload(BaseModel): name: str +class GenerateSummaryPayload(BaseModel): + document_list: list[str] + + class DocumentBatchDownloadZipPayload(BaseModel): """Request payload for bulk downloading documents as a zip archive.""" @@ -125,6 +130,7 @@ register_schema_models( RetrievalModel, DocumentRetryPayload, DocumentRenamePayload, + GenerateSummaryPayload, DocumentBatchDownloadZipPayload, ) @@ -312,6 +318,13 @@ class DatasetDocumentListApi(Resource): paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items + + DocumentService.enrich_documents_with_summary_index_status( + documents=documents, + dataset=dataset, + tenant_id=current_tenant_id, + ) + if fetch: for document in documents: completed_segments = ( @@ -797,6 +810,7 @@ class DocumentApi(DocumentResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "need_summary": document.need_summary if document.need_summary is not None else False, } else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) @@ -832,6 +846,7 @@ class DocumentApi(DocumentResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "need_summary": document.need_summary if document.need_summary is not None else False, } return response, 200 @@ -1255,3 +1270,137 @@ class DocumentPipelineExecutionLogApi(DocumentResource): "input_data": log.input_data, "datasource_node_id": log.datasource_node_id, }, 200 + + +@console_ns.route("/datasets//documents/generate-summary") +class DocumentGenerateSummaryApi(Resource): + @console_ns.doc("generate_summary_for_documents") + @console_ns.doc(description="Generate summary index for documents") + @console_ns.doc(params={"dataset_id": "Dataset ID"}) + @console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__]) + @console_ns.response(200, "Summary generation started successfully") + @console_ns.response(400, "Invalid request or dataset configuration") + @console_ns.response(403, "Permission denied") + @console_ns.response(404, "Dataset not found") + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") + def post(self, dataset_id): + """ + Generate summary index for specified documents. + + This endpoint checks if the dataset configuration supports summary generation + (indexing_technique must be 'high_quality' and summary_index_setting.enable must be true), + then asynchronously generates summary indexes for the provided documents. + """ + current_user, _ = current_account_with_tenant() + dataset_id = str(dataset_id) + + # Get dataset + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + + # Check permissions + if not current_user.is_dataset_editor: + raise Forbidden() + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # Validate request payload + payload = GenerateSummaryPayload.model_validate(console_ns.payload or {}) + document_list = payload.document_list + + if not document_list: + from werkzeug.exceptions import BadRequest + + raise BadRequest("document_list cannot be empty.") + + # Check if dataset configuration supports summary generation + if dataset.indexing_technique != "high_quality": + raise ValueError( + f"Summary generation is only available for 'high_quality' indexing technique. " + f"Current indexing technique: {dataset.indexing_technique}" + ) + + summary_index_setting = dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.") + + # Verify all documents exist and belong to the dataset + documents = DocumentService.get_documents_by_ids(dataset_id, document_list) + + if len(documents) != len(document_list): + found_ids = {doc.id for doc in documents} + missing_ids = set(document_list) - found_ids + raise NotFound(f"Some documents not found: {list(missing_ids)}") + + # Dispatch async tasks for each document + for document in documents: + # Skip qa_model documents as they don't generate summaries + if document.doc_form == "qa_model": + logger.info("Skipping summary generation for qa_model document %s", document.id) + continue + + # Dispatch async task + generate_summary_index_task.delay(dataset_id, document.id) + logger.info( + "Dispatched summary generation task for document %s in dataset %s", + document.id, + dataset_id, + ) + + return {"result": "success"}, 200 + + +@console_ns.route("/datasets//documents//summary-status") +class DocumentSummaryStatusApi(DocumentResource): + @console_ns.doc("get_document_summary_status") + @console_ns.doc(description="Get summary index generation status for a document") + @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @console_ns.response(200, "Summary status retrieved successfully") + @console_ns.response(404, "Document not found") + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id): + """ + Get summary index generation status for a document. + + Returns: + - total_segments: Total number of segments in the document + - summary_status: Dictionary with status counts + - completed: Number of summaries completed + - generating: Number of summaries being generated + - error: Number of summaries with errors + - not_started: Number of segments without summary records + - summaries: List of summary records with status and content preview + """ + current_user, _ = current_account_with_tenant() + dataset_id = str(dataset_id) + document_id = str(document_id) + + # Get dataset + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + + # Check permissions + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # Get summary status detail from service + from services.summary_index_service import SummaryIndexService + + result = SummaryIndexService.get_document_summary_status_detail( + document_id=document_id, + dataset_id=dataset_id, + ) + + return result, 200 diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 08e1ddd3e0..23a668112d 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -41,6 +41,17 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task +def _get_segment_with_summary(segment, dataset_id): + """Helper function to marshal segment and add summary information.""" + from services.summary_index_service import SummaryIndexService + + segment_dict = dict(marshal(segment, segment_fields)) + # Query summary for this segment (only enabled summaries) + summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) + segment_dict["summary"] = summary.summary_content if summary else None + return segment_dict + + class SegmentListQuery(BaseModel): limit: int = Field(default=20, ge=1, le=100) status: list[str] = Field(default_factory=list) @@ -63,6 +74,7 @@ class SegmentUpdatePayload(BaseModel): keywords: list[str] | None = None regenerate_child_chunks: bool = False attachment_ids: list[str] | None = None + summary: str | None = None # Summary content for summary index class BatchImportPayload(BaseModel): @@ -181,8 +193,25 @@ class DatasetDocumentSegmentListApi(Resource): segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) + # Query summaries for all segments in this page (batch query for efficiency) + segment_ids = [segment.id for segment in segments.items] + summaries = {} + if segment_ids: + from services.summary_index_service import SummaryIndexService + + summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id) + # Only include enabled summaries (already filtered by service) + summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()} + + # Add summary to each segment + segments_with_summary = [] + for segment in segments.items: + segment_dict = dict(marshal(segment, segment_fields)) + segment_dict["summary"] = summaries.get(segment.id) + segments_with_summary.append(segment_dict) + response = { - "data": marshal(segments.items, segment_fields), + "data": segments_with_summary, "limit": limit, "total": segments.total, "total_pages": segments.pages, @@ -328,7 +357,7 @@ class DatasetDocumentSegmentAddApi(Resource): payload_dict = payload.model_dump(exclude_none=True) SegmentService.segment_create_args_validate(payload_dict, document) segment = SegmentService.create_segment(payload_dict, document, dataset) - return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 @console_ns.route("/datasets//documents//segments/") @@ -390,10 +419,12 @@ class DatasetDocumentSegmentUpdateApi(Resource): payload = SegmentUpdatePayload.model_validate(console_ns.payload or {}) payload_dict = payload.model_dump(exclude_none=True) SegmentService.segment_create_args_validate(payload_dict, document) + + # Update segment (summary update with change detection is handled in SegmentService.update_segment) segment = SegmentService.update_segment( SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset ) - return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 + return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200 @setup_required @login_required diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 932cb4fcce..e62be13c2f 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,6 +1,13 @@ -from flask_restx import Resource +from flask_restx import Resource, fields from controllers.common.schema import register_schema_model +from fields.hit_testing_fields import ( + child_chunk_fields, + document_fields, + files_fields, + hit_testing_record_fields, + segment_fields, +) from libs.login import login_required from .. import console_ns @@ -14,13 +21,45 @@ from ..wraps import ( register_schema_model(console_ns, HitTestingPayload) +def _get_or_create_model(model_name: str, field_def): + """Get or create a flask_restx model to avoid dict type issues in Swagger.""" + existing = console_ns.models.get(model_name) + if existing is None: + existing = console_ns.model(model_name, field_def) + return existing + + +# Register models for flask_restx to avoid dict type issues in Swagger +document_model = _get_or_create_model("HitTestingDocument", document_fields) + +segment_fields_copy = segment_fields.copy() +segment_fields_copy["document"] = fields.Nested(document_model) +segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy) + +child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields) +files_model = _get_or_create_model("HitTestingFile", files_fields) + +hit_testing_record_fields_copy = hit_testing_record_fields.copy() +hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model) +hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model)) +hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model)) +hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy) + +# Response model for hit testing API +hit_testing_response_fields = { + "query": fields.String, + "records": fields.List(fields.Nested(hit_testing_record_model)), +} +hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields) + + @console_ns.route("/datasets//hit-testing") class HitTestingApi(Resource, DatasetsHitTestingBase): @console_ns.doc("test_dataset_retrieval") @console_ns.doc(description="Test dataset knowledge retrieval") @console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.expect(console_ns.models[HitTestingPayload.__name__]) - @console_ns.response(200, "Hit testing completed successfully") + @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model) @console_ns.response(404, "Dataset not found") @console_ns.response(400, "Invalid parameters") @setup_required diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 28864a140a..c11f64585a 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -46,6 +46,7 @@ class DatasetCreatePayload(BaseModel): retrieval_model: RetrievalModel | None = None embedding_model: str | None = None embedding_model_provider: str | None = None + summary_index_setting: dict | None = None class DatasetUpdatePayload(BaseModel): @@ -217,6 +218,7 @@ class DatasetListApi(DatasetApiResource): embedding_model_provider=payload.embedding_model_provider, embedding_model_name=payload.embedding_model, retrieval_model=payload.retrieval_model, + summary_index_setting=payload.summary_index_setting, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index c85c1cf81e..a01524f1bc 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -45,6 +45,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( Segmentation, ) from services.file_service import FileService +from services.summary_index_service import SummaryIndexService class DocumentTextCreatePayload(BaseModel): @@ -508,6 +509,12 @@ class DocumentListApi(DatasetApiResource): ) documents = paginated_documents.items + DocumentService.enrich_documents_with_summary_index_status( + documents=documents, + dataset=dataset, + tenant_id=tenant_id, + ) + response = { "data": marshal(documents, document_fields), "has_more": len(documents) == query_params.limit, @@ -612,6 +619,16 @@ class DocumentApi(DatasetApiResource): if metadata not in self.METADATA_CHOICES: raise InvalidMetadataError(f"Invalid metadata value: {metadata}") + # Calculate summary_index_status if needed + summary_index_status = None + has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True + if has_summary_index and document.need_summary is True: + summary_index_status = SummaryIndexService.get_document_summary_index_status( + document_id=document_id, + dataset_id=dataset_id, + tenant_id=tenant_id, + ) + if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": @@ -646,6 +663,8 @@ class DocumentApi(DatasetApiResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "summary_index_status": summary_index_status, + "need_summary": document.need_summary if document.need_summary is not None else False, } else: dataset_process_rules = DatasetService.get_process_rules(dataset_id) @@ -681,6 +700,8 @@ class DocumentApi(DatasetApiResource): "display_status": document.display_status, "doc_form": document.doc_form, "doc_language": document.doc_language, + "summary_index_status": summary_index_status, + "need_summary": document.need_summary if document.need_summary is not None else False, } return response diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 74c6d2eca6..d1e2f16b6f 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -79,6 +79,7 @@ class AppGenerateResponseConverter(ABC): "document_name": resource["document_name"], "score": resource["score"], "content": resource["content"], + "summary": resource.get("summary"), } ) metadata["retriever_resources"] = updated_resources diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index d4093b5245..b1ba3c3e2a 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator class PreviewDetail(BaseModel): content: str + summary: str | None = None child_chunks: list[str] | None = None diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 120fb73cdb..c0fefef3d0 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -104,6 +104,8 @@ def download(f: File, /): ): return _download_file_content(f.storage_key) elif f.transfer_method == FileTransferMethod.REMOTE_URL: + if f.remote_url is None: + raise ValueError("Missing file remote_url") response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() return response.content @@ -134,6 +136,8 @@ def _download_file_content(path: str, /): def _get_encoded_string(f: File, /): match f.transfer_method: case FileTransferMethod.REMOTE_URL: + if f.remote_url is None: + raise ValueError("Missing file remote_url") response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() data = response.content diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 128c64ff2c..ddccfbaf45 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -4,8 +4,10 @@ Proxy requests to avoid SSRF import logging import time +from typing import Any, TypeAlias import httpx +from pydantic import TypeAdapter, ValidationError from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client @@ -18,6 +20,9 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] +Headers: TypeAlias = dict[str, str] +_HEADERS_ADAPTER = TypeAdapter(Headers) + _SSL_VERIFIED_POOL_KEY = "ssrf:verified" _SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified" _SSRF_CLIENT_LIMITS = httpx.Limits( @@ -76,7 +81,7 @@ def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client: ) -def _get_user_provided_host_header(headers: dict | None) -> str | None: +def _get_user_provided_host_header(headers: Headers | None) -> str | None: """ Extract the user-provided Host header from the headers dict. @@ -92,7 +97,7 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None: return None -def _inject_trace_headers(headers: dict | None) -> dict: +def _inject_trace_headers(headers: Headers | None) -> Headers: """ Inject W3C traceparent header for distributed tracing. @@ -125,7 +130,7 @@ def _inject_trace_headers(headers: dict | None) -> dict: return headers -def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def make_request(method: str, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: # Convert requests-style allow_redirects to httpx-style follow_redirects if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") @@ -142,10 +147,15 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): # prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY) + if not isinstance(verify_option, bool): + raise ValueError("ssl_verify must be a boolean") client = _get_ssrf_client(verify_option) # Inject traceparent header for distributed tracing (when OTEL is not enabled) - headers = kwargs.get("headers") or {} + try: + headers: Headers = _HEADERS_ADAPTER.validate_python(kwargs.get("headers") or {}) + except ValidationError as e: + raise ValueError("headers must be a mapping of string keys to string values") from e headers = _inject_trace_headers(headers) kwargs["headers"] = headers @@ -198,25 +208,25 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}") -def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def get(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("GET", url, max_retries=max_retries, **kwargs) -def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def post(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("POST", url, max_retries=max_retries, **kwargs) -def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def put(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("PUT", url, max_retries=max_retries, **kwargs) -def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def patch(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("PATCH", url, max_retries=max_retries, **kwargs) -def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("DELETE", url, max_retries=max_retries, **kwargs) -def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): +def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("HEAD", url, max_retries=max_retries, **kwargs) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index f1b50f360b..e172e88298 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -311,14 +311,18 @@ class IndexingRunner: qa_preview_texts: list[QAPreviewDetail] = [] total_segments = 0 + # doc_form represents the segmentation method (general, parent-child, QA) index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() + # one extract_setting is one source document for extract_setting in extract_settings: # extract processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) + # Extract document content text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) + # Cleaning and segmentation documents = index_processor.transform( text_docs, current_user=None, @@ -361,6 +365,12 @@ class IndexingRunner: if doc_form and doc_form == "qa_model": return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[]) + + # Generate summary preview + summary_index_setting = tmp_processing_rule.get("summary_index_setting") + if summary_index_setting and summary_index_setting.get("enable") and preview_texts: + preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting) + return IndexingEstimate(total_segments=total_segments, preview=preview_texts) def _extract( diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ec2b7f2d44..d46cf049dd 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -434,3 +434,20 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex You should edit the prompt according to the IDEAL OUTPUT.""" INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}.""" + +DEFAULT_GENERATOR_SUMMARY_PROMPT = ( + """Summarize the following content. Extract only the key information and main points. """ + """Remove redundant details. + +Requirements: +1. Write a concise summary in plain text +2. Use the same language as the input content +3. Focus on important facts, concepts, and details +4. If images are included, describe their key information +5. Do not use words like "ๅฅฝ็š„", "ok", "I understand", "This text discusses", "The content mentions" +6. Write directly without extra words + +Output only the summary text. Start summarizing now: + +""" +) diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 45f0335c2e..c3e50eaddd 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,10 +1,11 @@ import decimal import hashlib -from threading import Lock +import logging -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, ValidationError +from redis import RedisError -import contexts +from configs import dify_config from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.model_entities import ( @@ -24,6 +25,9 @@ from core.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) class AIModel(BaseModel): @@ -144,34 +148,60 @@ class AIModel(BaseModel): plugin_model_manager = PluginModelClient() cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" - # sort credentials sorted_credentials = sorted(credentials.items()) if credentials else [] cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) + cached_schema_json = None try: - contexts.plugin_model_schemas.get() - except LookupError: - contexts.plugin_model_schemas.set({}) - contexts.plugin_model_schema_lock.set(Lock()) - - with contexts.plugin_model_schema_lock.get(): - if cache_key in contexts.plugin_model_schemas.get(): - return contexts.plugin_model_schemas.get()[cache_key] - - schema = plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, - model=model, - credentials=credentials or {}, + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, ) + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning( + "Failed to validate cached plugin model schema for model %s", + model, + exc_info=True, + ) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) - if schema: - contexts.plugin_model_schemas.get()[cache_key] = schema + schema = plugin_model_manager.get_model_schema( + tenant_id=self.tenant_id, + user_id="unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model_type=self.model_type.value, + model=model, + credentials=credentials or {}, + ) - return schema + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: """ diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 28f162a928..64538a6779 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -5,7 +5,11 @@ import logging from collections.abc import Sequence from threading import Lock +from pydantic import ValidationError +from redis import RedisError + import contexts +from configs import dify_config from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -18,6 +22,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from extensions.ext_redis import redis_client from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) @@ -175,34 +180,60 @@ class ModelProviderFactory: """ plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" - # sort credentials sorted_credentials = sorted(credentials.items()) if credentials else [] cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) + cached_schema_json = None try: - contexts.plugin_model_schemas.get() - except LookupError: - contexts.plugin_model_schemas.set({}) - contexts.plugin_model_schema_lock.set(Lock()) - - with contexts.plugin_model_schema_lock.get(): - if cache_key in contexts.plugin_model_schemas.get(): - return contexts.plugin_model_schemas.get()[cache_key] - - schema = self.plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_id, - provider=provider_name, - model_type=model_type.value, - model=model, - credentials=credentials or {}, + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, ) + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning( + "Failed to validate cached plugin model schema for model %s", + model, + exc_info=True, + ) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) - if schema: - contexts.plugin_model_schemas.get()[cache_key] = schema + schema = self.plugin_model_manager.get_model_schema( + tenant_id=self.tenant_id, + user_id="unknown", + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials or {}, + ) - return schema + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema def get_models( self, diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 8ec1ce6242..91c16ce079 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -24,7 +24,13 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding +from models.dataset import ( + ChildChunk, + Dataset, + DocumentSegment, + DocumentSegmentSummary, + SegmentAttachmentBinding, +) from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService @@ -389,15 +395,15 @@ class RetrievalService: .all() } - records = [] - include_segment_ids = set() - segment_child_map = {} - valid_dataset_documents = {} image_doc_ids: list[Any] = [] child_index_node_ids = [] index_node_ids = [] doc_to_document_map = {} + summary_segment_ids = set() # Track segments retrieved via summary + summary_score_map: dict[str, float] = {} # Map original_chunk_id to summary score + + # First pass: collect all document IDs and identify summary documents for document in documents: document_id = document.metadata.get("document_id") if document_id not in dataset_documents: @@ -408,16 +414,39 @@ class RetrievalService: continue valid_dataset_documents[document_id] = dataset_document + doc_id = document.metadata.get("doc_id") or "" + doc_to_document_map[doc_id] = document + + # Check if this is a summary document + is_summary = document.metadata.get("is_summary", False) + if is_summary: + # For summary documents, find the original chunk via original_chunk_id + original_chunk_id = document.metadata.get("original_chunk_id") + if original_chunk_id: + summary_segment_ids.add(original_chunk_id) + # Save summary's score for later use + summary_score = document.metadata.get("score") + if summary_score is not None: + try: + summary_score_float = float(summary_score) + # If the same segment has multiple summary hits, take the highest score + if original_chunk_id not in summary_score_map: + summary_score_map[original_chunk_id] = summary_score_float + else: + summary_score_map[original_chunk_id] = max( + summary_score_map[original_chunk_id], summary_score_float + ) + except (ValueError, TypeError): + # Skip invalid score values + pass + continue # Skip adding to other lists for summary documents + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - doc_id = document.metadata.get("doc_id") or "" - doc_to_document_map[doc_id] = document if document.metadata.get("doc_type") == DocType.IMAGE: image_doc_ids.append(doc_id) else: child_index_node_ids.append(doc_id) else: - doc_id = document.metadata.get("doc_id") or "" - doc_to_document_map[doc_id] = document if document.metadata.get("doc_type") == DocType.IMAGE: image_doc_ids.append(doc_id) else: @@ -433,6 +462,7 @@ class RetrievalService: attachment_map: dict[str, list[dict[str, Any]]] = {} child_chunk_map: dict[str, list[ChildChunk]] = {} doc_segment_map: dict[str, list[str]] = {} + segment_summary_map: dict[str, str] = {} # Map segment_id to summary content with session_factory.create_session() as session: attachments = cls.get_segment_attachment_infos(image_doc_ids, session) @@ -447,6 +477,7 @@ class RetrievalService: doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"]) else: doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]] + child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids)) child_index_nodes = session.execute(child_chunk_stmt).scalars().all() @@ -470,6 +501,7 @@ class RetrievalService: index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore for index_node_segment in index_node_segments: doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id] + if segment_ids: document_segment_stmt = select(DocumentSegment).where( DocumentSegment.enabled == True, @@ -481,6 +513,40 @@ class RetrievalService: if index_node_segments: segments.extend(index_node_segments) + # Handle summary documents: query segments by original_chunk_id + if summary_segment_ids: + summary_segment_ids_list = list(summary_segment_ids) + summary_segment_stmt = select(DocumentSegment).where( + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id.in_(summary_segment_ids_list), + ) + summary_segments = session.execute(summary_segment_stmt).scalars().all() # type: ignore + segments.extend(summary_segments) + # Add summary segment IDs to segment_ids for summary query + for seg in summary_segments: + if seg.id not in segment_ids: + segment_ids.append(seg.id) + + # Batch query summaries for segments retrieved via summary (only enabled summaries) + if summary_segment_ids: + summaries = ( + session.query(DocumentSegmentSummary) + .filter( + DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)), + DocumentSegmentSummary.status == "completed", + DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries + ) + .all() + ) + for summary in summaries: + if summary.summary_content: + segment_summary_map[summary.chunk_id] = summary.summary_content + + include_segment_ids = set() + segment_child_map: dict[str, dict[str, Any]] = {} + records: list[dict[str, Any]] = [] + for segment in segments: child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) @@ -489,30 +555,44 @@ class RetrievalService: if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: if segment.id not in include_segment_ids: include_segment_ids.add(segment.id) + # Check if this segment was retrieved via summary + # Use summary score as base score if available, otherwise 0.0 + max_score = summary_score_map.get(segment.id, 0.0) + if child_chunks or attachment_infos: child_chunk_details = [] - max_score = 0.0 for child_chunk in child_chunks: - document = doc_to_document_map[child_chunk.index_node_id] + child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id) + if child_document: + child_score = child_document.metadata.get("score", 0.0) + else: + child_score = 0.0 child_chunk_detail = { "id": child_chunk.id, "content": child_chunk.content, "position": child_chunk.position, - "score": document.metadata.get("score", 0.0) if document else 0.0, + "score": child_score, } child_chunk_details.append(child_chunk_detail) - max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0) + max_score = max(max_score, child_score) for attachment_info in attachment_infos: - file_document = doc_to_document_map[attachment_info["id"]] - max_score = max( - max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0 - ) + file_document = doc_to_document_map.get(attachment_info["id"]) + if file_document: + max_score = max(max_score, file_document.metadata.get("score", 0.0)) map_detail = { "max_score": max_score, "child_chunks": child_chunk_details, } segment_child_map[segment.id] = map_detail + else: + # No child chunks or attachments, use summary score if available + summary_score = summary_score_map.get(segment.id) + if summary_score is not None: + segment_child_map[segment.id] = { + "max_score": summary_score, + "child_chunks": [], + } record: dict[str, Any] = { "segment": segment, } @@ -520,14 +600,23 @@ class RetrievalService: else: if segment.id not in include_segment_ids: include_segment_ids.add(segment.id) - max_score = 0.0 - segment_document = doc_to_document_map.get(segment.index_node_id) - if segment_document: - max_score = max(max_score, segment_document.metadata.get("score", 0.0)) + + # Check if this segment was retrieved via summary + # Use summary score if available (summary retrieval takes priority) + max_score = summary_score_map.get(segment.id, 0.0) + + # If not retrieved via summary, use original segment's score + if segment.id not in summary_score_map: + segment_document = doc_to_document_map.get(segment.index_node_id) + if segment_document: + max_score = max(max_score, segment_document.metadata.get("score", 0.0)) + + # Also consider attachment scores for attachment_info in attachment_infos: file_doc = doc_to_document_map.get(attachment_info["id"]) if file_doc: max_score = max(max_score, file_doc.metadata.get("score", 0.0)) + record = { "segment": segment, "score": max_score, @@ -576,9 +665,16 @@ class RetrievalService: else None ) + # Extract summary if this segment was retrieved via summary + summary_content = segment_summary_map.get(segment.id) + # Create RetrievalSegments object retrieval_segment = RetrievalSegments( - segment=segment, child_chunks=child_chunks_list, score=score, files=files + segment=segment, + child_chunks=child_chunks_list, + score=score, + files=files, + summary=summary_content, ) result.append(retrieval_segment) diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index f8c62b908a..4a4a458f2e 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -391,46 +391,78 @@ class QdrantVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - """Return docs most similar by bm25. + """Return docs most similar by full-text search. + + Searches each keyword separately and merges results to ensure documents + matching ANY keyword are returned (OR logic). Results are capped at top_k. + + Args: + query: Search query text. Multi-word queries are split into keywords, + with each keyword searched separately. Limited to 10 keywords. + **kwargs: Additional search parameters (top_k, document_ids_filter) + Returns: - List of documents most similar to the query text and distance for each. + List of up to top_k unique documents matching any query keyword. """ from qdrant_client.http import models - scroll_filter = models.Filter( - must=[ - models.FieldCondition( - key="group_id", - match=models.MatchValue(value=self._group_id), - ), - models.FieldCondition( - key="page_content", - match=models.MatchText(text=query), - ), - ] - ) + # Build base must conditions (AND logic) for metadata filters + base_must_conditions: list = [ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self._group_id), + ), + ] + document_ids_filter = kwargs.get("document_ids_filter") if document_ids_filter: - if scroll_filter.must: - scroll_filter.must.append( - models.FieldCondition( - key="metadata.document_id", - match=models.MatchAny(any=document_ids_filter), - ) + base_must_conditions.append( + models.FieldCondition( + key="metadata.document_id", + match=models.MatchAny(any=document_ids_filter), ) - response = self._client.scroll( - collection_name=self._collection_name, - scroll_filter=scroll_filter, - limit=kwargs.get("top_k", 2), - with_payload=True, - with_vectors=True, - ) - results = response[0] - documents = [] - for result in results: - if result: - document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY) - documents.append(document) + ) + + # Split query into keywords, deduplicate and limit to prevent DoS + keywords = list(dict.fromkeys(kw.strip() for kw in query.strip().split() if kw.strip()))[:10] + + if not keywords: + return [] + + top_k = kwargs.get("top_k", 2) + seen_ids: set[str | int] = set() + documents: list[Document] = [] + + # Search each keyword separately and merge results. + # This ensures each keyword gets its own search, preventing one keyword's + # results from completely overshadowing another's due to scroll ordering. + for keyword in keywords: + scroll_filter = models.Filter( + must=[ + *base_must_conditions, + models.FieldCondition( + key="page_content", + match=models.MatchText(text=keyword), + ), + ] + ) + + response = self._client.scroll( + collection_name=self._collection_name, + scroll_filter=scroll_filter, + limit=top_k, + with_payload=True, + with_vectors=True, + ) + results = response[0] + + for result in results: + if result and result.id not in seen_ids: + seen_ids.add(result.id) + document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY) + documents.append(document) + if len(documents) >= top_k: + return documents return documents diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index b54a37b49e..f6834ab87b 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel): child_chunks: list[RetrievalChildChunk] | None = None score: float | None = None files: list[dict[str, str | int]] | None = None + summary: str | None = None # Summary content if retrieved via summary index diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py index 9f66cd9a03..aec5c353f8 100644 --- a/api/core/rag/entities/citation_metadata.py +++ b/api/core/rag/entities/citation_metadata.py @@ -22,3 +22,4 @@ class RetrievalSourceMetadata(BaseModel): doc_metadata: dict[str, Any] | None = None title: str | None = None files: list[dict[str, Any]] | None = None + summary: str | None = None diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 511f5a698d..1ddbfc5864 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,4 +1,7 @@ -"""Abstract interface for document loader implementations.""" +"""Word (.docx) document extractor used for RAG ingestion. + +Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`). +""" import logging import mimetypes @@ -8,7 +11,6 @@ import tempfile import uuid from urllib.parse import urlparse -import httpx from docx import Document as DocxDocument from docx.oxml.ns import qn from docx.text.run import Run @@ -44,7 +46,7 @@ class WordExtractor(BaseExtractor): # If the file is a web path, download it to a temporary file, and use that if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): - response = httpx.get(self.file_path, timeout=None) + response = ssrf_proxy.get(self.file_path) if response.status_code != 200: response.close() @@ -55,6 +57,7 @@ class WordExtractor(BaseExtractor): self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115 try: self.temp_file.write(response.content) + self.temp_file.flush() finally: response.close() self.file_path = self.temp_file.name diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index e36b54eedd..151a3de7d9 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse import httpx from configs import dify_config +from core.entities.knowledge_entities import PreviewDetail from core.helper import ssrf_proxy from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.doc_type import DocType @@ -45,6 +46,17 @@ class BaseIndexProcessor(ABC): def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: raise NotImplementedError + @abstractmethod + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + For each segment in preview_texts, generate a summary using LLM and attach it to the segment. + The summary can be stored in a new attribute, e.g., summary. + This method should be implemented by subclasses. + """ + raise NotImplementedError + @abstractmethod def load( self, diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index cf68cff7dc..ab91e29145 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,9 +1,27 @@ """Paragraph index processor.""" +import logging +import re import uuid from collections.abc import Mapping -from typing import Any +from typing import Any, cast +logger = logging.getLogger(__name__) + +from core.entities.knowledge_entities import PreviewDetail +from core.file import File, FileTransferMethod, FileType, file_manager +from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT +from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.provider_manager import ProviderManager from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService @@ -17,12 +35,17 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols +from core.workflow.nodes.llm import llm_utils +from extensions.ext_database import db +from factories.file_factory import build_from_mapping from libs import helper +from models import UploadFile from models.account import Account -from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule +from services.summary_index_service import SummaryIndexService class ParagraphIndexProcessor(BaseIndexProcessor): @@ -108,6 +131,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword.add_texts(documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + # Note: Summary indexes are now disabled (not deleted) when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. + # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: @@ -227,3 +273,322 @@ class ParagraphIndexProcessor(BaseIndexProcessor): } else: raise ValueError("Chunks is not a list") + + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + For each segment, concurrently call generate_summary to generate a summary + and write it to the summary attribute of PreviewDetail. + In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception. + """ + import concurrent.futures + + from flask import current_app + + # Capture Flask app context for worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def process(preview: PreviewDetail) -> None: + """Generate summary for a single preview item.""" + if flask_app: + # Ensure Flask app context in worker thread + with flask_app.app_context(): + summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + preview.summary = summary + else: + # Fallback: try without app context (may fail) + summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + preview.summary = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_texts)) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor: + futures = [executor.submit(process, preview) for preview in preview_texts] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode (indexing-estimate), if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise ValueError(error_summary) + + return preview_texts + + @staticmethod + def generate_summary( + tenant_id: str, + text: str, + summary_index_setting: dict | None = None, + segment_id: str | None = None, + ) -> tuple[str, LLMUsage]: + """ + Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt, + and supports vision models by including images from the segment attachments or text content. + + Args: + tenant_id: Tenant ID + text: Text content to summarize + summary_index_setting: Summary index configuration + segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table + + Returns: + Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object + """ + if not summary_index_setting or not summary_index_setting.get("enable"): + raise ValueError("summary_index_setting is required and must be enabled to generate summary.") + + model_name = summary_index_setting.get("model_name") + model_provider_name = summary_index_setting.get("model_provider_name") + summary_prompt = summary_index_setting.get("summary_prompt") + + if not model_name or not model_provider_name: + raise ValueError("model_name and model_provider_name are required in summary_index_setting") + + # Import default summary prompt + if not summary_prompt: + summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT + + provider_manager = ProviderManager() + provider_model_bundle = provider_manager.get_provider_model_bundle( + tenant_id, model_provider_name, ModelType.LLM + ) + model_instance = ModelInstance(provider_model_bundle, model_name) + + # Get model schema to check if vision is supported + model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials) + supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features + + # Extract images if model supports vision + image_files = [] + if supports_vision: + # First, try to get images from SegmentAttachmentBinding (preferred method) + if segment_id: + image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id) + + # If no images from attachments, fall back to extracting from text + if not image_files: + image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text) + + # Build prompt messages + prompt_messages = [] + + if image_files: + # If we have images, create a UserPromptMessage with both text and images + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] + + # Add images first + for file in image_files: + try: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW + ) + prompt_message_contents.append(file_content) + except Exception as e: + logger.warning("Failed to convert image file to prompt message content: %s", str(e)) + continue + + # Add text content + if prompt_message_contents: # Only add text if we successfully added images + prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}")) + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + # If image conversion failed, fall back to text-only + prompt = f"{summary_prompt}\n{text}" + prompt_messages.append(UserPromptMessage(content=prompt)) + else: + # No images, use simple text prompt + prompt = f"{summary_prompt}\n{text}" + prompt_messages.append(UserPromptMessage(content=prompt)) + + result = model_instance.invoke_llm( + prompt_messages=cast(list[PromptMessage], prompt_messages), model_parameters={}, stream=False + ) + + # Type assertion: when stream=False, invoke_llm returns LLMResult, not Generator + if not isinstance(result, LLMResult): + raise ValueError("Expected LLMResult when stream=False") + + summary_content = getattr(result.message, "content", "") + usage = result.usage + + # Deduct quota for summary generation (same as workflow nodes) + try: + llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + except Exception as e: + # Log but don't fail summary generation if quota deduction fails + logger.warning("Failed to deduct quota for summary generation: %s", str(e)) + + return summary_content, usage + + @staticmethod + def _extract_images_from_text(tenant_id: str, text: str) -> list[File]: + """ + Extract images from markdown text and convert them to File objects. + + Args: + tenant_id: Tenant ID + text: Text content that may contain markdown image links + + Returns: + List of File objects representing images found in the text + """ + # Extract markdown images using regex pattern + pattern = r"!\[.*?\]\((.*?)\)" + images = re.findall(pattern, text) + + if not images: + return [] + + upload_file_id_list = [] + + for image in images: + # For data before v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For data after v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?" + match = re.search(pattern, image) + if match: + upload_file_id = match.group(1) + upload_file_id_list.append(upload_file_id) + continue + + # For tools directory - direct file formats (e.g., .png, .jpg, etc.) + pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?" + match = re.search(pattern, image) + if match: + # Tool files are handled differently, skip for now + continue + + if not upload_file_id_list: + return [] + + # Get unique IDs for database query + unique_upload_file_ids = list(set(upload_file_id_list)) + upload_files = ( + db.session.query(UploadFile) + .where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id) + .all() + ) + + # Create File objects from UploadFile records + file_objects = [] + for upload_file in upload_files: + # Only process image files + if not upload_file.mime_type or "image" not in upload_file.mime_type: + continue + + mapping = { + "upload_file_id": upload_file.id, + "transfer_method": FileTransferMethod.LOCAL_FILE.value, + "type": FileType.IMAGE.value, + } + + try: + file_obj = build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + file_objects.append(file_obj) + except Exception as e: + logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e)) + continue + + return file_objects + + @staticmethod + def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]: + """ + Extract images from SegmentAttachmentBinding table (preferred method). + This matches how DatasetRetrieval gets segment attachments. + + Args: + tenant_id: Tenant ID + segment_id: Segment ID to fetch attachments for + + Returns: + List of File objects representing images found in segment attachments + """ + from sqlalchemy import select + + # Query attachments from SegmentAttachmentBinding table + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == segment_id, + SegmentAttachmentBinding.tenant_id == tenant_id, + ) + ).all() + + if not attachments_with_bindings: + return [] + + file_objects = [] + for _, upload_file in attachments_with_bindings: + # Only process image files + if not upload_file.mime_type or "image" not in upload_file.mime_type: + continue + + try: + # Create File object directly (similar to DatasetRetrieval) + file_obj = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + ) + file_objects.append(file_obj) + except Exception as e: + logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e)) + continue + + return file_objects diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 0366f3259f..961df2e50c 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -1,11 +1,14 @@ """Paragraph index processor.""" import json +import logging import uuid from collections.abc import Mapping from typing import Any from configs import dify_config +from core.db.session_factory import session_factory +from core.entities.knowledge_entities import PreviewDetail from core.model_manager import ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService @@ -25,6 +28,9 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm from models.dataset import Document as DatasetDocument from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) class ParentChildIndexProcessor(BaseIndexProcessor): @@ -135,6 +141,30 @@ class ParentChildIndexProcessor(BaseIndexProcessor): def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids + # Note: Summary indexes are now disabled (not deleted) when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. + # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + if dataset.indexing_technique == "high_quality": delete_child_chunks = kwargs.get("delete_child_chunks") or False precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") @@ -326,3 +356,91 @@ class ParentChildIndexProcessor(BaseIndexProcessor): "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), } + + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary + and write it to the summary attribute of PreviewDetail. + In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception. + + Note: For parent-child structure, we only generate summaries for parent chunks. + """ + import concurrent.futures + + from flask import current_app + + # Capture Flask app context for worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def process(preview: PreviewDetail) -> None: + """Generate summary for a single preview item (parent chunk).""" + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + if flask_app: + # Ensure Flask app context in worker thread + with flask_app.app_context(): + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=tenant_id, + text=preview.content, + summary_index_setting=summary_index_setting, + ) + preview.summary = summary + else: + # Fallback: try without app context (may fail) + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=tenant_id, + text=preview.content, + summary_index_setting=summary_index_setting, + ) + preview.summary = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_texts)) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor: + futures = [executor.submit(process, preview) for preview in preview_texts] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode (indexing-estimate), if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise ValueError(error_summary) + + return preview_texts diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 1183d5fbd7..272d2ed351 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -11,6 +11,8 @@ import pandas as pd from flask import Flask, current_app from werkzeug.datastructures import FileStorage +from core.db.session_factory import session_factory +from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.retrieval_service import RetrievalService @@ -25,9 +27,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.account import Account -from models.dataset import Dataset +from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule +from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) @@ -144,6 +147,31 @@ class QAIndexProcessor(BaseIndexProcessor): vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + # Note: Summary indexes are now disabled (not deleted) when segments are disabled. + # This method is called for actual deletion scenarios (e.g., when segment is deleted). + # For disable operations, disable_summaries_for_segments is called directly in the task. + # Note: qa_model doesn't generate summaries, but we clean them for completeness + # Only delete summaries if explicitly requested (e.g., when segment is actually deleted) + delete_summaries = kwargs.get("delete_summaries", False) + if delete_summaries: + if node_ids: + # Find segments by index_node_id + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ) + .all() + ) + segment_ids = [segment.id for segment in segments] + if segment_ids: + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) + else: + # Delete all summaries for the dataset + SummaryIndexService.delete_summaries_for_segments(dataset, None) + vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -212,6 +240,17 @@ class QAIndexProcessor(BaseIndexProcessor): "total_segments": len(qa_chunks.qa_chunks), } + def generate_summary_preview( + self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + ) -> list[PreviewDetail]: + """ + QA model doesn't generate summaries, so this method returns preview_texts unchanged. + + Note: QA model uses question-answer pairs, which don't require summary generation. + """ + # QA model doesn't generate summaries, return as-is + return preview_texts + def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): format_documents = [] if document_node.page_content is None or not document_node.page_content.strip(): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index f8f85d141a..541c241ae5 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -236,20 +236,24 @@ class DatasetRetrieval: if records: for record in records: segment = record.segment + # Build content: if summary exists, add it before the segment content if segment.answer: - document_context_list.append( - DocumentContext( - content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=record.score, - ) - ) + segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}" else: - document_context_list.append( - DocumentContext( - content=segment.get_sign_content(), - score=record.score, - ) + segment_content = segment.get_sign_content() + + # If summary exists, prepend it to the content + if record.summary: + final_content = f"{record.summary}\n{segment_content}" + else: + final_content = segment_content + + document_context_list.append( + DocumentContext( + content=final_content, + score=record.score, ) + ) if vision_enabled: attachments_with_bindings = db.session.execute( select(SegmentAttachmentBinding, UploadFile) @@ -316,6 +320,9 @@ class DatasetRetrieval: source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: source.content = segment.content + # Add summary if this segment was retrieved via summary + if hasattr(record, "summary") and record.summary: + source.summary = record.summary retrieval_resource_list.append(source) if hit_callback and retrieval_resource_list: retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True) diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f96510fb45..057ec41f65 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -169,20 +169,24 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if records: for record in records: segment = record.segment + # Build content: if summary exists, add it before the segment content if segment.answer: - document_context_list.append( - DocumentContext( - content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=record.score, - ) - ) + segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}" else: - document_context_list.append( - DocumentContext( - content=segment.get_sign_content(), - score=record.score, - ) + segment_content = segment.get_sign_content() + + # If summary exists, prepend it to the content + if record.summary: + final_content = f"{record.summary}\n{segment_content}" + else: + final_content = segment_content + + document_context_list.append( + DocumentContext( + content=final_content, + score=record.score, ) + ) if self.return_resource: for record in records: @@ -216,6 +220,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: source.content = segment.content + # Add summary if this segment was retrieved via summary + if hasattr(record, "summary") and record.summary: + source.summary = record.summary retrieval_resource_list.append(source) if self.return_resource and retrieval_resource_list: diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 3daca90b9b..bfeb9b5b79 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData): type: str = "knowledge-index" chunk_structure: str index_chunk_variable_selector: list[str] + indexing_technique: str | None = None + summary_index_setting: dict | None = None diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 17ca4bef7b..b88c2d510f 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,9 +1,11 @@ +import concurrent.futures import datetime import logging import time from collections.abc import Mapping from typing import Any +from flask import current_app from sqlalchemy import func, select from core.app.entities.app_invoke_entities import InvokeFrom @@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.runtime import VariablePool from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary +from services.summary_index_service import SummaryIndexService +from tasks.generate_summary_index_task import generate_summary_index_task from .entities import KnowledgeIndexNodeData from .exc import ( @@ -67,7 +71,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): # index knowledge try: if is_preview: - outputs = self._get_preview_output(node_data.chunk_structure, chunks) + # Preview mode: generate summaries for chunks directly without saving to database + # Format preview and generate summaries on-the-fly + # Get indexing_technique and summary_index_setting from node_data (workflow graph config) + # or fallback to dataset if not available in node_data + indexing_technique = node_data.indexing_technique or dataset.indexing_technique + summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting + + outputs = self._get_preview_output_with_summaries( + node_data.chunk_structure, + chunks, + dataset=dataset, + indexing_technique=indexing_technique, + summary_index_setting=summary_index_setting, + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, @@ -148,6 +165,11 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): ) .scalar() ) + # Update need_summary based on dataset's summary_index_setting + if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True: + document.need_summary = True + else: + document.need_summary = False db.session.add(document) # update document segment status db.session.query(DocumentSegment).where( @@ -163,6 +185,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): db.session.commit() + # Generate summary index if enabled + self._handle_summary_index_generation(dataset, document, variable_pool) + return { "dataset_id": ds_id_value, "dataset_name": dataset_name_value, @@ -173,9 +198,304 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): "display_status": "completed", } - def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]: + def _handle_summary_index_generation( + self, + dataset: Dataset, + document: Document, + variable_pool: VariablePool, + ) -> None: + """ + Handle summary index generation based on mode (debug/preview or production). + + Args: + dataset: Dataset containing the document + document: Document to generate summaries for + variable_pool: Variable pool to check invoke_from + """ + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return + + # Check if summary index is enabled + summary_index_setting = dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + return + + # Skip qa_model documents + if document.doc_form == "qa_model": + return + + # Determine if in preview/debug mode + invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER + + if is_preview: + try: + # Query segments that need summary generation + query = db.session.query(DocumentSegment).filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, + ) + segments = query.all() + + if not segments: + logger.info("No segments found for document %s", document.id) + return + + # Filter segments based on mode + segments_to_process = [] + for segment in segments: + # Skip if summary already exists + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed") + .first() + ) + if existing_summary: + continue + + # For parent-child mode, all segments are parent chunks, so process all + segments_to_process.append(segment) + + if not segments_to_process: + logger.info("No segments need summary generation for document %s", document.id) + return + + # Use ThreadPoolExecutor for concurrent generation + flask_app = current_app._get_current_object() # type: ignore + max_workers = min(10, len(segments_to_process)) # Limit to 10 workers + + def process_segment(segment: DocumentSegment) -> None: + """Process a single segment in a thread with Flask app context.""" + with flask_app.app_context(): + try: + SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting) + except Exception: + logger.exception( + "Failed to generate summary for segment %s", + segment.id, + ) + # Continue processing other segments + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(process_segment, segment) for segment in segments_to_process] + # Wait for all tasks to complete + concurrent.futures.wait(futures) + + logger.info( + "Successfully generated summary index for %s segments in document %s", + len(segments_to_process), + document.id, + ) + except Exception: + logger.exception("Failed to generate summary index for document %s", document.id) + # Don't fail the entire indexing process if summary generation fails + else: + # Production mode: asynchronous generation + logger.info( + "Queuing summary index generation task for document %s (production mode)", + document.id, + ) + try: + generate_summary_index_task.delay(dataset.id, document.id, None) + logger.info("Summary index generation task queued for document %s", document.id) + except Exception: + logger.exception( + "Failed to queue summary index generation task for document %s", + document.id, + ) + # Don't fail the entire indexing process if task queuing fails + + def _get_preview_output_with_summaries( + self, + chunk_structure: str, + chunks: Any, + dataset: Dataset, + indexing_technique: str | None = None, + summary_index_setting: dict | None = None, + ) -> Mapping[str, Any]: + """ + Generate preview output with summaries for chunks in preview mode. + This method generates summaries on-the-fly without saving to database. + + Args: + chunk_structure: Chunk structure type + chunks: Chunks to generate preview for + dataset: Dataset object (for tenant_id) + indexing_technique: Indexing technique from node config or dataset + summary_index_setting: Summary index setting from node config or dataset + """ index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() - return index_processor.format_preview(chunks) + preview_output = index_processor.format_preview(chunks) + + # Check if summary index is enabled + if indexing_technique != "high_quality": + return preview_output + + if not summary_index_setting or not summary_index_setting.get("enable"): + return preview_output + + # Generate summaries for chunks + if "preview" in preview_output and isinstance(preview_output["preview"], list): + chunk_count = len(preview_output["preview"]) + logger.info( + "Generating summaries for %s chunks in preview mode (dataset: %s)", + chunk_count, + dataset.id, + ) + # Use ParagraphIndexProcessor's generate_summary method + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + # Get Flask app for application context in worker threads + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def generate_summary_for_chunk(preview_item: dict) -> None: + """Generate summary for a single chunk.""" + if "content" in preview_item: + # Set Flask application context in worker thread + if flask_app: + with flask_app.app_context(): + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + text=preview_item["content"], + summary_index_setting=summary_index_setting, + ) + if summary: + preview_item["summary"] = summary + else: + # Fallback: try without app context (may fail) + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + text=preview_item["content"], + summary_index_setting=summary_index_setting, + ) + if summary: + preview_item["summary"] = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_output["preview"])) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor: + futures = [ + executor.submit(generate_summary_for_chunk, preview_item) + for preview_item in preview_output["preview"] + ] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode, if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise KnowledgeIndexNodeError(error_summary) + + completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None) + logger.info( + "Completed summary generation for preview chunks: %s/%s succeeded", + completed_count, + len(preview_output["preview"]), + ) + + return preview_output + + def _get_preview_output( + self, + chunk_structure: str, + chunks: Any, + dataset: Dataset | None = None, + variable_pool: VariablePool | None = None, + ) -> Mapping[str, Any]: + index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() + preview_output = index_processor.format_preview(chunks) + + # If dataset is provided, try to enrich preview with summaries + if dataset and variable_pool: + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if document_id: + document = db.session.query(Document).filter_by(id=document_id.value).first() + if document: + # Query summaries for this document + summaries = ( + db.session.query(DocumentSegmentSummary) + .filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, + ) + .all() + ) + + if summaries: + # Create a map of segment content to summary for matching + # Use content matching as chunks in preview might not be indexed yet + summary_by_content = {} + for summary in summaries: + segment = ( + db.session.query(DocumentSegment) + .filter_by(id=summary.chunk_id, dataset_id=dataset.id) + .first() + ) + if segment: + # Normalize content for matching (strip whitespace) + normalized_content = segment.content.strip() + summary_by_content[normalized_content] = summary.summary_content + + # Enrich preview with summaries by content matching + if "preview" in preview_output and isinstance(preview_output["preview"], list): + matched_count = 0 + for preview_item in preview_output["preview"]: + if "content" in preview_item: + # Normalize content for matching + normalized_chunk_content = preview_item["content"].strip() + if normalized_chunk_content in summary_by_content: + preview_item["summary"] = summary_by_content[normalized_chunk_content] + matched_count += 1 + + if matched_count > 0: + logger.info( + "Enriched preview with %s existing summaries (dataset: %s, document: %s)", + matched_count, + dataset.id, + document.id, + ) + + return preview_output @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 8670a71aa3..3c4850ebac 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -419,6 +419,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" else: source["content"] = segment.get_sign_content() + # Add summary if available + if record.summary: + source["summary"] = record.summary retrieval_resource_list.append(source) if retrieval_resource_list: retrieval_resource_list = sorted( diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index dfb55dcd80..17d82c2118 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -685,6 +685,8 @@ class LLMNode(Node[LLMNodeData]): if "content" not in item: raise InvalidContextStructureError(f"Invalid context structure: {item}") + if item.get("summary"): + context_str += item["summary"] + "\n" context_str += item["content"] + "\n" retriever_resource = self._convert_to_original_retriever_resource(item) @@ -746,6 +748,7 @@ class LLMNode(Node[LLMNodeData]): page=metadata.get("page"), doc_metadata=metadata.get("doc_metadata"), files=context_dict.get("files"), + summary=context_dict.get("summary"), ) return source diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 08cf96c1c1..af983f6d87 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -102,6 +102,8 @@ def init_app(app: DifyApp) -> Celery: imports = [ "tasks.async_workflow_tasks", # trigger workers "tasks.trigger_processing_tasks", # async trigger processing + "tasks.generate_summary_index_task", # summary index generation + "tasks.regenerate_summary_index_task", # summary index regeneration ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 1e5ec7d200..ff6578098b 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -39,6 +39,14 @@ dataset_retrieval_model_fields = { "score_threshold_enabled": fields.Boolean, "score_threshold": fields.Float, } + +dataset_summary_index_fields = { + "enable": fields.Boolean, + "model_name": fields.String, + "model_provider_name": fields.String, + "summary_prompt": fields.String, +} + external_retrieval_model_fields = { "top_k": fields.Integer, "score_threshold": fields.Float, @@ -83,6 +91,7 @@ dataset_detail_fields = { "embedding_model_provider": fields.String, "embedding_available": fields.Boolean, "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), + "summary_index_setting": fields.Nested(dataset_summary_index_fields), "tags": fields.List(fields.Nested(tag_fields)), "doc_form": fields.String, "external_knowledge_info": fields.Nested(external_knowledge_info_fields), diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 9be59f7454..35a2a04f3e 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -33,6 +33,11 @@ document_fields = { "hit_count": fields.Integer, "doc_form": fields.String, "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), + # Summary index generation status: + # "SUMMARIZING" (when task is queued and generating) + "summary_index_status": fields.String, + # Whether this document needs summary index generation + "need_summary": fields.Boolean, } document_with_segments_fields = { @@ -60,6 +65,10 @@ document_with_segments_fields = { "completed_segments": fields.Integer, "total_segments": fields.Integer, "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), + # Summary index generation status: + # "SUMMARIZING" (when task is queued and generating) + "summary_index_status": fields.String, + "need_summary": fields.Boolean, # Whether this document needs summary index generation } dataset_and_document_fields = { diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index e70f9fa722..0b54992835 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -58,4 +58,5 @@ hit_testing_record_fields = { "score": fields.Float, "tsne_position": fields.Raw, "files": fields.List(fields.Nested(files_fields)), + "summary": fields.String, # Summary content if retrieved via summary index } diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index c81e482f73..e6c3b42f93 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -36,6 +36,7 @@ class RetrieverResource(ResponseModel): segment_position: int | None = None index_node_hash: str | None = None content: str | None = None + summary: str | None = None created_at: int | None = None @field_validator("created_at", mode="before") diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 56d6b68378..2ce9fb154c 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -49,4 +49,5 @@ segment_fields = { "stopped_at": TimestampField, "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "attachments": fields.List(fields.Nested(attachment_fields)), + "summary": fields.String, # Summary content for the segment } diff --git a/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py new file mode 100644 index 0000000000..3c2e0822e1 --- /dev/null +++ b/api/migrations/versions/2026_01_27_1815-788d3099ae3a_add_summary_index_feature.py @@ -0,0 +1,107 @@ +"""add summary index feature + +Revision ID: 788d3099ae3a +Revises: 9d77545f524e +Create Date: 2026-01-27 18:15:45.277928 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + +def _is_pg(conn): + return conn.dialect.name == "postgresql" + +# revision identifiers, used by Alembic. +revision = '788d3099ae3a' +down_revision = '9d77545f524e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + conn = op.get_bind() + if _is_pg(conn): + op.create_table('document_segment_summaries', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('chunk_id', models.types.StringUUID(), nullable=False), + sa.Column('summary_content', models.types.LongText(), nullable=True), + sa.Column('summary_index_node_id', sa.String(length=255), nullable=True), + sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False), + sa.Column('error', models.types.LongText(), nullable=True), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey') + ) + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False) + batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False) + batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False) + batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True)) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True)) + else: + # MySQL: Use compatible syntax + op.create_table( + 'document_segment_summaries', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('chunk_id', models.types.StringUUID(), nullable=False), + sa.Column('summary_content', models.types.LongText(), nullable=True), + sa.Column('summary_index_node_id', sa.String(length=255), nullable=True), + sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False), + sa.Column('error', models.types.LongText(), nullable=True), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', models.types.StringUUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='document_segment_summaries_pkey'), + ) + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.create_index('document_segment_summaries_chunk_id_idx', ['chunk_id'], unique=False) + batch_op.create_index('document_segment_summaries_dataset_id_idx', ['dataset_id'], unique=False) + batch_op.create_index('document_segment_summaries_document_id_idx', ['document_id'], unique=False) + batch_op.create_index('document_segment_summaries_status_idx', ['status'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True)) + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_column('need_summary') + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('summary_index_setting') + + with op.batch_alter_table('document_segment_summaries', schema=None) as batch_op: + batch_op.drop_index('document_segment_summaries_status_idx') + batch_op.drop_index('document_segment_summaries_document_id_idx') + batch_op.drop_index('document_segment_summaries_dataset_id_idx') + batch_op.drop_index('document_segment_summaries_chunk_id_idx') + + op.drop_table('document_segment_summaries') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index 62f11b8c72..6ab8f372bf 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -72,6 +72,7 @@ class Dataset(Base): keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10")) collection_binding_id = mapped_column(StringUUID, nullable=True) retrieval_model = mapped_column(AdjustedJSON, nullable=True) + summary_index_setting = mapped_column(AdjustedJSON, nullable=True) built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) icon_info = mapped_column(AdjustedJSON, nullable=True) runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'")) @@ -419,6 +420,7 @@ class Document(Base): doc_metadata = mapped_column(AdjustedJSON, nullable=True) doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) doc_language = mapped_column(String(255), nullable=True) + need_summary: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -1575,3 +1577,36 @@ class SegmentAttachmentBinding(Base): segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class DocumentSegmentSummary(Base): + __tablename__ = "document_segment_summaries" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"), + sa.Index("document_segment_summaries_dataset_id_idx", "dataset_id"), + sa.Index("document_segment_summaries_document_id_idx", "document_id"), + sa.Index("document_segment_summaries_chunk_id_idx", "chunk_id"), + sa.Index("document_segment_summaries_status_idx", "status"), + ) + + id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # corresponds to DocumentSegment.id or parent chunk id + chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + summary_content: Mapped[str] = mapped_column(LongText, nullable=True) + summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) + summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'")) + error: Mapped[str] = mapped_column(LongText, nullable=True) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + disabled_by = mapped_column(StringUUID, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + + def __repr__(self): + return f"" diff --git a/api/models/model.py b/api/models/model.py index be0cfd58a7..c1c6e04ce9 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -657,16 +657,22 @@ class AccountTrialAppRecord(Base): return user -class ExporleBanner(Base): +class ExporleBanner(TypeBase): __tablename__ = "exporle_banners" __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) - id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) - content = mapped_column(sa.JSON, nullable=False) - link = mapped_column(String(255), nullable=False) - sort = mapped_column(sa.Integer, nullable=False) - status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying")) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) + content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) + link: Mapped[str] = mapped_column(String(255), nullable=False) + sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) + status: Mapped[str] = mapped_column( + sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled" + ) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + language: Mapped[str] = mapped_column( + String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US" + ) class OAuthProviderApp(TypeBase): diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index be9a0e9279..0b3fcbe4ae 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -89,6 +89,7 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t from tasks.document_indexing_update_task import document_indexing_update_task from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task +from tasks.regenerate_summary_index_task import regenerate_summary_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.retry_document_indexing_task import retry_document_indexing_task from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task @@ -211,6 +212,7 @@ class DatasetService: embedding_model_provider: str | None = None, embedding_model_name: str | None = None, retrieval_model: RetrievalModel | None = None, + summary_index_setting: dict | None = None, ): # check if dataset name already exists if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): @@ -253,6 +255,8 @@ class DatasetService: dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None dataset.permission = permission or DatasetPermissionEnum.ONLY_ME dataset.provider = provider + if summary_index_setting is not None: + dataset.summary_index_setting = summary_index_setting db.session.add(dataset) db.session.flush() @@ -476,6 +480,11 @@ class DatasetService: if external_retrieval_model: dataset.retrieval_model = external_retrieval_model + # Update summary index setting if provided + summary_index_setting = data.get("summary_index_setting", None) + if summary_index_setting is not None: + dataset.summary_index_setting = summary_index_setting + # Update basic dataset properties dataset.name = data.get("name", dataset.name) dataset.description = data.get("description", dataset.description) @@ -564,6 +573,9 @@ class DatasetService: # update Retrieval model if data.get("retrieval_model"): filtered_data["retrieval_model"] = data["retrieval_model"] + # update summary index setting + if data.get("summary_index_setting"): + filtered_data["summary_index_setting"] = data.get("summary_index_setting") # update icon info if data.get("icon_info"): filtered_data["icon_info"] = data.get("icon_info") @@ -572,12 +584,27 @@ class DatasetService: db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) db.session.commit() + # Reload dataset to get updated values + db.session.refresh(dataset) + # update pipeline knowledge base node data DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id) # Trigger vector index task if indexing technique changed if action: deal_dataset_vector_index_task.delay(dataset.id, action) + # If embedding_model changed, also regenerate summary vectors + if action == "update": + regenerate_summary_index_task.delay( + dataset.id, + regenerate_reason="embedding_model_changed", + regenerate_vectors_only=True, + ) + + # Note: summary_index_setting changes do not trigger automatic regeneration of existing summaries. + # The new setting will only apply to: + # 1. New documents added after the setting change + # 2. Manual summary generation requests return dataset @@ -616,6 +643,7 @@ class DatasetService: knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue] knowledge_index_node_data["keyword_number"] = dataset.keyword_number + knowledge_index_node_data["summary_index_setting"] = dataset.summary_index_setting node["data"] = knowledge_index_node_data updated = True except Exception: @@ -854,6 +882,54 @@ class DatasetService: ) filtered_data["collection_binding_id"] = dataset_collection_binding.id + @staticmethod + def _check_summary_index_setting_model_changed(dataset: Dataset, data: dict[str, Any]) -> bool: + """ + Check if summary_index_setting model (model_name or model_provider_name) has changed. + + Args: + dataset: Current dataset object + data: Update data dictionary + + Returns: + bool: True if summary model changed, False otherwise + """ + # Check if summary_index_setting is being updated + if "summary_index_setting" not in data or data.get("summary_index_setting") is None: + return False + + new_summary_setting = data.get("summary_index_setting") + old_summary_setting = dataset.summary_index_setting + + # If new setting is disabled, no need to regenerate + if not new_summary_setting or not new_summary_setting.get("enable"): + return False + + # If old setting doesn't exist, no need to regenerate (no existing summaries to regenerate) + # Note: This task only regenerates existing summaries, not generates new ones + if not old_summary_setting: + return False + + # Compare model_name and model_provider_name + old_model_name = old_summary_setting.get("model_name") + old_model_provider = old_summary_setting.get("model_provider_name") + new_model_name = new_summary_setting.get("model_name") + new_model_provider = new_summary_setting.get("model_provider_name") + + # Check if model changed + if old_model_name != new_model_name or old_model_provider != new_model_provider: + logger.info( + "Summary index setting model changed for dataset %s: old=%s/%s, new=%s/%s", + dataset.id, + old_model_provider, + old_model_name, + new_model_provider, + new_model_name, + ) + return True + + return False + @staticmethod def update_rag_pipeline_dataset_settings( session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False @@ -889,6 +965,9 @@ class DatasetService: else: raise ValueError("Invalid index method") dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting session.add(dataset) else: if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure: @@ -994,6 +1073,9 @@ class DatasetService: if dataset.keyword_number != knowledge_configuration.keyword_number: dataset.keyword_number = knowledge_configuration.keyword_number dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting session.add(dataset) session.commit() if action: @@ -1314,6 +1396,50 @@ class DocumentService: upload_file = DocumentService._get_upload_file_for_upload_file_document(document) return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True) + @staticmethod + def enrich_documents_with_summary_index_status( + documents: Sequence[Document], + dataset: Dataset, + tenant_id: str, + ) -> None: + """ + Enrich documents with summary_index_status based on dataset summary index settings. + + This method calculates and sets the summary_index_status for each document that needs summary. + Documents that don't need summary or when summary index is disabled will have status set to None. + + Args: + documents: List of Document instances to enrich + dataset: Dataset instance containing summary_index_setting + tenant_id: Tenant ID for summary status lookup + """ + # Check if dataset has summary index enabled + has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True + + # Filter documents that need summary calculation + documents_need_summary = [doc for doc in documents if doc.need_summary is True] + document_ids_need_summary = [str(doc.id) for doc in documents_need_summary] + + # Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled) + summary_status_map: dict[str, str | None] = {} + if has_summary_index and document_ids_need_summary: + from services.summary_index_service import SummaryIndexService + + summary_status_map = SummaryIndexService.get_documents_summary_index_status( + document_ids=document_ids_need_summary, + dataset_id=dataset.id, + tenant_id=tenant_id, + ) + + # Add summary_index_status to each document + for document in documents: + if has_summary_index and document.need_summary is True: + # Get status from map, default to None (not queued yet) + document.summary_index_status = summary_status_map.get(str(document.id)) # type: ignore[attr-defined] + else: + # Return null if summary index is not enabled or document doesn't need summary + document.summary_index_status = None # type: ignore[attr-defined] + @staticmethod def prepare_document_batch_download_zip( *, @@ -1964,6 +2090,8 @@ class DocumentService: DuplicateDocumentIndexingTaskProxy( dataset.tenant_id, dataset.id, duplicate_document_ids ).delay() + # Note: Summary index generation is triggered in document_indexing_task after indexing completes + # to ensure segments are available. See tasks/document_indexing_task.py except LockNotOwnedError: pass @@ -2268,6 +2396,11 @@ class DocumentService: name: str, batch: str, ): + # Set need_summary based on dataset's summary_index_setting + need_summary = False + if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True: + need_summary = True + document = Document( tenant_id=dataset.tenant_id, dataset_id=dataset.id, @@ -2281,6 +2414,7 @@ class DocumentService: created_by=account.id, doc_form=document_form, doc_language=document_language, + need_summary=need_summary, ) doc_metadata = {} if dataset.built_in_field_enabled: @@ -2505,6 +2639,7 @@ class DocumentService: embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, retrieval_model=retrieval_model.model_dump() if retrieval_model else None, + summary_index_setting=knowledge_config.summary_index_setting, is_multimodal=knowledge_config.is_multimodal, ) @@ -2686,6 +2821,14 @@ class DocumentService: if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") + # valid summary index setting + summary_index_setting = args["process_rule"].get("summary_index_setting") + if summary_index_setting and summary_index_setting.get("enable"): + if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]: + raise ValueError("Summary index model name is required") + if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]: + raise ValueError("Summary index model provider name is required") + @staticmethod def batch_update_document_status( dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user @@ -3154,6 +3297,35 @@ class SegmentService: if args.enabled or keyword_changed: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) + # update summary index if summary is provided and has changed + if args.summary is not None: + # When user manually provides summary, allow saving even if summary_index_setting doesn't exist + # summary_index_setting is only needed for LLM generation, not for manual summary vectorization + # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting + if dataset.indexing_technique == "high_quality": + # Query existing summary from database + from models.dataset import DocumentSegmentSummary + + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .first() + ) + + # Check if summary has changed + existing_summary_content = existing_summary.summary_content if existing_summary else None + if existing_summary_content != args.summary: + # Summary has changed, update it + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary) + except Exception: + logger.exception("Failed to update summary for segment %s", segment.id) + # Don't fail the entire update if summary update fails else: segment_hash = helper.generate_text_hash(content) tokens = 0 @@ -3228,6 +3400,73 @@ class SegmentService: elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) + # Handle summary index when content changed + if dataset.indexing_technique == "high_quality": + from models.dataset import DocumentSegmentSummary + + existing_summary = ( + db.session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .first() + ) + + if args.summary is None: + # User didn't provide summary, auto-regenerate if segment previously had summary + # Auto-regeneration only happens if summary_index_setting exists and enable is True + if ( + existing_summary + and dataset.summary_index_setting + and dataset.summary_index_setting.get("enable") is True + ): + # Segment previously had summary, regenerate it with new content + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, dataset.summary_index_setting + ) + logger.info("Auto-regenerated summary for segment %s after content change", segment.id) + except Exception: + logger.exception("Failed to auto-regenerate summary for segment %s", segment.id) + # Don't fail the entire update if summary regeneration fails + else: + # User provided summary, check if it has changed + # Manual summary updates are allowed even if summary_index_setting doesn't exist + existing_summary_content = existing_summary.summary_content if existing_summary else None + if existing_summary_content != args.summary: + # Summary has changed, use user-provided summary + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary) + logger.info("Updated summary for segment %s with user-provided content", segment.id) + except Exception: + logger.exception("Failed to update summary for segment %s", segment.id) + # Don't fail the entire update if summary update fails + else: + # Summary hasn't changed, regenerate based on new content + # Auto-regeneration only happens if summary_index_setting exists and enable is True + if ( + existing_summary + and dataset.summary_index_setting + and dataset.summary_index_setting.get("enable") is True + ): + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, dataset.summary_index_setting + ) + logger.info( + "Regenerated summary for segment %s after content change (summary unchanged)", + segment.id, + ) + except Exception: + logger.exception("Failed to regenerate summary for segment %s", segment.id) + # Don't fail the entire update if summary regeneration fails # update multimodel vector index VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset) except Exception as e: @@ -3616,6 +3855,39 @@ class SegmentService: ) return result if isinstance(result, DocumentSegment) else None + @classmethod + def get_segments_by_document_and_dataset( + cls, + document_id: str, + dataset_id: str, + status: str | None = None, + enabled: bool | None = None, + ) -> Sequence[DocumentSegment]: + """ + Get segments for a document in a dataset with optional filtering. + + Args: + document_id: Document ID + dataset_id: Dataset ID + status: Optional status filter (e.g., "completed") + enabled: Optional enabled filter (True/False) + + Returns: + Sequence of DocumentSegment instances + """ + query = select(DocumentSegment).where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ) + + if status is not None: + query = query.where(DocumentSegment.status == status) + + if enabled is not None: + query = query.where(DocumentSegment.enabled == enabled) + + return db.session.scalars(query).all() + class DatasetCollectionBindingService: @classmethod diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 7959734e89..8dc5b93501 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -119,6 +119,7 @@ class KnowledgeConfig(BaseModel): data_source: DataSource | None = None process_rule: ProcessRule | None = None retrieval_model: RetrievalModel | None = None + summary_index_setting: dict | None = None doc_form: str = "text_model" doc_language: str = "English" embedding_model: str | None = None @@ -141,6 +142,7 @@ class SegmentUpdateArgs(BaseModel): regenerate_child_chunks: bool = False enabled: bool | None = None attachment_ids: list[str] | None = None + summary: str | None = None # Summary content for summary index class ChildChunkUpdateArgs(BaseModel): diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index cbb0efcc2a..041ae4edba 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -116,6 +116,8 @@ class KnowledgeConfiguration(BaseModel): embedding_model: str = "" keyword_number: int | None = 10 retrieval_model: RetrievalSetting + # add summary index setting + summary_index_setting: dict | None = None @field_validator("embedding_model_provider", mode="before") @classmethod diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index c1c6e204fb..be1ce834f6 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -343,6 +343,9 @@ class RagPipelineDslService: dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider elif knowledge_configuration.indexing_technique == "economy": dataset.keyword_number = knowledge_configuration.keyword_number + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) self._session.commit() @@ -477,6 +480,9 @@ class RagPipelineDslService: dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider elif knowledge_configuration.indexing_technique == "economy": dataset.keyword_number = knowledge_configuration.keyword_number + # Update summary_index_setting if provided + if knowledge_configuration.summary_index_setting is not None: + dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) self._session.commit() diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py new file mode 100644 index 0000000000..b8e1f8bc3f --- /dev/null +++ b/api/services/summary_index_service.py @@ -0,0 +1,1432 @@ +"""Summary index service for generating and managing document segment summaries.""" + +import logging +import time +import uuid +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy.orm import Session + +from core.db.session_factory import session_factory +from core.model_manager import ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.models.document import Document +from libs import helper +from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary +from models.dataset import Document as DatasetDocument + +logger = logging.getLogger(__name__) + + +class SummaryIndexService: + """Service for generating and managing summary indexes.""" + + @staticmethod + def generate_summary_for_segment( + segment: DocumentSegment, + dataset: Dataset, + summary_index_setting: dict, + ) -> tuple[str, LLMUsage]: + """ + Generate summary for a single segment. + + Args: + segment: DocumentSegment to generate summary for + dataset: Dataset containing the segment + summary_index_setting: Summary index configuration + + Returns: + Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object + + Raises: + ValueError: If summary_index_setting is invalid or generation fails + """ + # Reuse the existing generate_summary method from ParagraphIndexProcessor + # Use lazy import to avoid circular import + from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + + summary_content, usage = ParagraphIndexProcessor.generate_summary( + tenant_id=dataset.tenant_id, + text=segment.content, + summary_index_setting=summary_index_setting, + segment_id=segment.id, + ) + + if not summary_content: + raise ValueError("Generated summary is empty") + + return summary_content, usage + + @staticmethod + def create_summary_record( + segment: DocumentSegment, + dataset: Dataset, + summary_content: str, + status: str = "generating", + ) -> DocumentSegmentSummary: + """ + Create or update a DocumentSegmentSummary record. + If a summary record already exists for this segment, it will be updated instead of creating a new one. + + Args: + segment: DocumentSegment to create summary for + dataset: Dataset containing the segment + summary_content: Generated summary content + status: Summary status (default: "generating") + + Returns: + Created or updated DocumentSegmentSummary instance + """ + with session_factory.create_session() as session: + # Check if summary record already exists + existing_summary = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if existing_summary: + # Update existing record + existing_summary.summary_content = summary_content + existing_summary.status = status + existing_summary.error = None # type: ignore[assignment] # Clear any previous errors + # Re-enable if it was disabled + if not existing_summary.enabled: + existing_summary.enabled = True + existing_summary.disabled_at = None + existing_summary.disabled_by = None + session.add(existing_summary) + session.flush() + return existing_summary + else: + # Create new record (enabled by default) + summary_record = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=summary_content, + status=status, + enabled=True, # Explicitly set enabled to True + ) + session.add(summary_record) + session.flush() + return summary_record + + @staticmethod + def vectorize_summary( + summary_record: DocumentSegmentSummary, + segment: DocumentSegment, + dataset: Dataset, + session: Session | None = None, + ) -> None: + """ + Vectorize summary and store in vector database. + + Args: + summary_record: DocumentSegmentSummary record + segment: Original DocumentSegment + dataset: Dataset containing the segment + session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one. + If not provided, creates a new session and commits automatically. + """ + if dataset.indexing_technique != "high_quality": + logger.warning( + "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality", + dataset.id, + ) + return + + # Get summary_record_id for later session queries + summary_record_id = summary_record.id + # Save the original session parameter for use in error handling + original_session = session + logger.debug( + "Starting vectorization for segment %s, summary_record_id=%s, using_provided_session=%s", + segment.id, + summary_record_id, + original_session is not None, + ) + + # Reuse existing index_node_id if available (like segment does), otherwise generate new one + old_summary_node_id = summary_record.summary_index_node_id + if old_summary_node_id: + # Reuse existing index_node_id (like segment behavior) + summary_index_node_id = old_summary_node_id + logger.debug("Reusing existing index_node_id %s for segment %s", summary_index_node_id, segment.id) + else: + # Generate new index node ID only for new summaries + summary_index_node_id = str(uuid.uuid4()) + logger.debug("Generated new index_node_id %s for segment %s", summary_index_node_id, segment.id) + + # Always regenerate hash (in case summary content changed) + summary_content = summary_record.summary_content + if not summary_content or not summary_content.strip(): + raise ValueError(f"Summary content is empty for segment {segment.id}, cannot vectorize") + summary_hash = helper.generate_text_hash(summary_content) + + # Delete old vector only if we're reusing the same index_node_id (to overwrite) + # If index_node_id changed, the old vector should have been deleted elsewhere + if old_summary_node_id and old_summary_node_id == summary_index_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s. Continuing with new vectorization.", + segment.id, + str(e), + ) + + # Calculate embedding tokens for summary (for logging and statistics) + embedding_tokens = 0 + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + if embedding_model: + tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content]) + embedding_tokens = tokens_list[0] if tokens_list else 0 + except Exception as e: + logger.warning("Failed to calculate embedding tokens for summary: %s", str(e)) + + # Create document with summary content and metadata + summary_document = Document( + page_content=summary_content, + metadata={ + "doc_id": summary_index_node_id, + "doc_hash": summary_hash, + "dataset_id": dataset.id, + "document_id": segment.document_id, + "original_chunk_id": segment.id, # Key: link to original chunk + "doc_type": DocType.TEXT, + "is_summary": True, # Identifier for summary documents + }, + ) + + # Vectorize and store with retry mechanism for connection errors + max_retries = 3 + retry_delay = 2.0 + + for attempt in range(max_retries): + try: + logger.debug( + "Attempting to vectorize summary for segment %s (attempt %s/%s)", + segment.id, + attempt + 1, + max_retries, + ) + vector = Vector(dataset) + # Use duplicate_check=False to ensure re-vectorization even if old vector still exists + # The old vector should have been deleted above, but if deletion failed, + # we still want to re-vectorize (upsert will overwrite) + vector.add_texts([summary_document], duplicate_check=False) + logger.debug( + "Successfully added summary vector to database for segment %s (attempt %s/%s)", + segment.id, + attempt + 1, + max_retries, + ) + + # Log embedding token usage + if embedding_tokens > 0: + logger.info( + "Summary embedding for segment %s used %s tokens", + segment.id, + embedding_tokens, + ) + + # Success - update summary record with index node info + # Use provided session if available, otherwise create a new one + use_provided_session = session is not None + if not use_provided_session: + logger.debug("Creating new session for vectorization of segment %s", segment.id) + session_context = session_factory.create_session() + session = session_context.__enter__() + else: + logger.debug("Using provided session for vectorization of segment %s", segment.id) + session_context = None # Don't use context manager for provided session + + # At this point, session is guaranteed to be not None + # Type narrowing: session is definitely not None after the if/else above + if session is None: + raise RuntimeError("Session should not be None at this point") + + try: + # Declare summary_record_in_session variable + summary_record_in_session: DocumentSegmentSummary | None + + # If using provided session, merge the summary_record into it + if use_provided_session: + # Merge the summary_record into the provided session + logger.debug( + "Merging summary_record (id=%s) into provided session for segment %s", + summary_record_id, + segment.id, + ) + summary_record_in_session = session.merge(summary_record) + logger.debug( + "Successfully merged summary_record for segment %s, merged_id=%s", + segment.id, + summary_record_in_session.id, + ) + else: + # Query the summary record in the new session + logger.debug( + "Querying summary_record by id=%s for segment %s in new session", + summary_record_id, + segment.id, + ) + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + ) + + if not summary_record_in_session: + # Record not found - try to find by chunk_id and dataset_id instead + logger.debug( + "Summary record not found by id=%s, trying chunk_id=%s and dataset_id=%s " + "for segment %s", + summary_record_id, + segment.id, + dataset.id, + segment.id, + ) + summary_record_in_session = ( + session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if not summary_record_in_session: + # Still not found - create a new one using the parameter data + logger.warning( + "Summary record not found in database for segment %s (id=%s), creating new one. " + "This may indicate a session isolation issue.", + segment.id, + summary_record_id, + ) + summary_record_in_session = DocumentSegmentSummary( + id=summary_record_id, # Use the same ID if available + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=summary_content, + summary_index_node_id=summary_index_node_id, + summary_index_node_hash=summary_hash, + tokens=embedding_tokens, + status="completed", + enabled=True, + ) + session.add(summary_record_in_session) + logger.info( + "Created new summary record (id=%s) for segment %s after vectorization", + summary_record_id, + segment.id, + ) + else: + # Found by chunk_id - update it + logger.info( + "Found summary record for segment %s by chunk_id " + "(id mismatch: expected %s, found %s). " + "This may indicate the record was created in a different session.", + segment.id, + summary_record_id, + summary_record_in_session.id, + ) + else: + logger.debug( + "Found summary_record (id=%s) for segment %s in new session", + summary_record_id, + segment.id, + ) + + # At this point, summary_record_in_session is guaranteed to be not None + if summary_record_in_session is None: + raise RuntimeError("summary_record_in_session should not be None at this point") + + # Update all fields including summary_content + # Always use the summary_content from the parameter (which is the latest from outer session) + # rather than relying on what's in the database, in case outer session hasn't committed yet + summary_record_in_session.summary_index_node_id = summary_index_node_id + summary_record_in_session.summary_index_node_hash = summary_hash + summary_record_in_session.tokens = embedding_tokens # Save embedding tokens + summary_record_in_session.status = "completed" + # Ensure summary_content is preserved (use the latest from summary_record parameter) + # This is critical: use the parameter value, not the database value + summary_record_in_session.summary_content = summary_content + # Explicitly update updated_at to ensure it's refreshed even if other fields haven't changed + summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) + session.add(summary_record_in_session) + + # Only commit if we created the session ourselves + if not use_provided_session: + logger.debug("Committing session for segment %s (self-created session)", segment.id) + session.commit() + logger.debug("Successfully committed session for segment %s", segment.id) + else: + # When using provided session, flush to ensure changes are written to database + # This prevents refresh() from overwriting our changes + logger.debug( + "Flushing session for segment %s (using provided session, caller will commit)", + segment.id, + ) + session.flush() + logger.debug("Successfully flushed session for segment %s", segment.id) + # If using provided session, let the caller handle commit + + logger.info( + "Successfully vectorized summary for segment %s, index_node_id=%s, index_node_hash=%s, " + "tokens=%s, summary_record_id=%s, use_provided_session=%s", + segment.id, + summary_index_node_id, + summary_hash, + embedding_tokens, + summary_record_in_session.id, + use_provided_session, + ) + # Update the original object for consistency + summary_record.summary_index_node_id = summary_index_node_id + summary_record.summary_index_node_hash = summary_hash + summary_record.tokens = embedding_tokens + summary_record.status = "completed" + summary_record.summary_content = summary_content + if summary_record_in_session.updated_at: + summary_record.updated_at = summary_record_in_session.updated_at + finally: + # Only close session if we created it ourselves + if not use_provided_session and session_context: + session_context.__exit__(None, None, None) + # Success, exit function + return + + except (ConnectionError, Exception) as e: + error_str = str(e).lower() + # Check if it's a connection-related error that might be transient + is_connection_error = any( + keyword in error_str + for keyword in [ + "connection", + "disconnected", + "timeout", + "network", + "could not connect", + "server disconnected", + "weaviate", + ] + ) + + if is_connection_error and attempt < max_retries - 1: + # Retry for connection errors + wait_time = retry_delay * (2**attempt) # Exponential backoff + logger.warning( + "Vectorization attempt %s/%s failed for segment %s (connection error): %s. " + "Retrying in %.1f seconds...", + attempt + 1, + max_retries, + segment.id, + str(e), + wait_time, + ) + time.sleep(wait_time) + continue + else: + # Final attempt failed or non-connection error - log and update status + logger.error( + "Failed to vectorize summary for segment %s after %s attempts: %s. " + "summary_record_id=%s, index_node_id=%s, use_provided_session=%s", + segment.id, + attempt + 1, + str(e), + summary_record_id, + summary_index_node_id, + session is not None, + exc_info=True, + ) + # Update error status in session + # Use the original_session saved at function start (the function parameter) + logger.debug( + "Updating error status for segment %s, summary_record_id=%s, has_original_session=%s", + segment.id, + summary_record_id, + original_session is not None, + ) + # Always create a new session for error handling to avoid issues with closed sessions + # Even if original_session was provided, we create a new one for safety + with session_factory.create_session() as error_session: + # Try to find the record by id first + # Note: Using assignment only (no type annotation) to avoid redeclaration error + summary_record_in_session = ( + error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + ) + if not summary_record_in_session: + # Try to find by chunk_id and dataset_id + logger.debug( + "Summary record not found by id=%s, trying chunk_id=%s and dataset_id=%s " + "for segment %s", + summary_record_id, + segment.id, + dataset.id, + segment.id, + ) + summary_record_in_session = ( + error_session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if summary_record_in_session: + summary_record_in_session.status = "error" + summary_record_in_session.error = f"Vectorization failed: {str(e)}" + summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) + error_session.add(summary_record_in_session) + error_session.commit() + logger.info( + "Updated error status in new session for segment %s, record_id=%s", + segment.id, + summary_record_in_session.id, + ) + # Update the original object for consistency + summary_record.status = "error" + summary_record.error = summary_record_in_session.error + summary_record.updated_at = summary_record_in_session.updated_at + else: + logger.warning( + "Could not update error status: summary record not found for segment %s (id=%s). " + "This may indicate a session isolation issue.", + segment.id, + summary_record_id, + ) + raise + + @staticmethod + def batch_create_summary_records( + segments: list[DocumentSegment], + dataset: Dataset, + status: str = "not_started", + ) -> None: + """ + Batch create summary records for segments with specified status. + If a record already exists, update its status. + + Args: + segments: List of DocumentSegment instances + dataset: Dataset containing the segments + status: Initial status for the records (default: "not_started") + """ + segment_ids = [segment.id for segment in segments] + if not segment_ids: + return + + with session_factory.create_session() as session: + # Query existing summary records + existing_summaries = ( + session.query(DocumentSegmentSummary) + .filter( + DocumentSegmentSummary.chunk_id.in_(segment_ids), + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .all() + ) + existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries} + + # Create or update records + for segment in segments: + existing_summary = existing_summary_map.get(segment.id) + if existing_summary: + # Update existing record + existing_summary.status = status + existing_summary.error = None # type: ignore[assignment] # Clear any previous errors + if not existing_summary.enabled: + existing_summary.enabled = True + existing_summary.disabled_at = None + existing_summary.disabled_by = None + session.add(existing_summary) + else: + # Create new record + summary_record = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content=None, # Will be filled later + status=status, + enabled=True, + ) + session.add(summary_record) + + @staticmethod + def update_summary_record_error( + segment: DocumentSegment, + dataset: Dataset, + error: str, + ) -> None: + """ + Update summary record with error status. + + Args: + segment: DocumentSegment + dataset: Dataset containing the segment + error: Error message + """ + with session_factory.create_session() as session: + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if summary_record: + summary_record.status = "error" + summary_record.error = error + session.add(summary_record) + session.commit() + else: + logger.warning("Summary record not found for segment %s when updating error", segment.id) + + @staticmethod + def generate_and_vectorize_summary( + segment: DocumentSegment, + dataset: Dataset, + summary_index_setting: dict, + ) -> DocumentSegmentSummary: + """ + Generate summary for a segment and vectorize it. + Assumes summary record already exists (created by batch_create_summary_records). + + Args: + segment: DocumentSegment to generate summary for + dataset: Dataset containing the segment + summary_index_setting: Summary index configuration + + Returns: + Created DocumentSegmentSummary instance + + Raises: + ValueError: If summary generation fails + """ + with session_factory.create_session() as session: + try: + # Get or refresh summary record in this session + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if not summary_record_in_session: + # If not found, create one + logger.warning("Summary record not found for segment %s, creating one", segment.id) + summary_record_in_session = DocumentSegmentSummary( + dataset_id=dataset.id, + document_id=segment.document_id, + chunk_id=segment.id, + summary_content="", + status="generating", + enabled=True, + ) + session.add(summary_record_in_session) + session.flush() + + # Update status to "generating" + summary_record_in_session.status = "generating" + summary_record_in_session.error = None # type: ignore[assignment] + session.add(summary_record_in_session) + # Don't flush here - wait until after vectorization succeeds + + # Generate summary (returns summary_content and llm_usage) + summary_content, llm_usage = SummaryIndexService.generate_summary_for_segment( + segment, dataset, summary_index_setting + ) + + # Update summary content + summary_record_in_session.summary_content = summary_content + session.add(summary_record_in_session) + # Flush to ensure summary_content is saved before vectorize_summary queries it + session.flush() + + # Log LLM usage for summary generation + if llm_usage and llm_usage.total_tokens > 0: + logger.info( + "Summary generation for segment %s used %s tokens (prompt: %s, completion: %s)", + segment.id, + llm_usage.total_tokens, + llm_usage.prompt_tokens, + llm_usage.completion_tokens, + ) + + # Vectorize summary (will delete old vector if exists before creating new one) + # Pass the session-managed record to vectorize_summary + # vectorize_summary will update status to "completed" and tokens in its own session + # vectorize_summary will also ensure summary_content is preserved + try: + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary_record_in_session, segment, dataset, session=session) + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary_record_in_session) + # Commit the session + # (summary_record_in_session should have status="completed" and tokens from refresh) + session.commit() + logger.info("Successfully generated and vectorized summary for segment %s", segment.id) + return summary_record_in_session + except Exception as vectorize_error: + # If vectorization fails, update status to error in current session + logger.exception("Failed to vectorize summary for segment %s", segment.id) + summary_record_in_session.status = "error" + summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}" + session.add(summary_record_in_session) + session.commit() + raise + + except Exception as e: + logger.exception("Failed to generate summary for segment %s", segment.id) + # Update summary record with error status + summary_record_in_session = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + if summary_record_in_session: + summary_record_in_session.status = "error" + summary_record_in_session.error = str(e) + session.add(summary_record_in_session) + session.commit() + raise + + @staticmethod + def generate_summaries_for_document( + dataset: Dataset, + document: DatasetDocument, + summary_index_setting: dict, + segment_ids: list[str] | None = None, + only_parent_chunks: bool = False, + ) -> list[DocumentSegmentSummary]: + """ + Generate summaries for all segments in a document including vectorization. + + Args: + dataset: Dataset containing the document + document: DatasetDocument to generate summaries for + summary_index_setting: Summary index configuration + segment_ids: Optional list of specific segment IDs to process + only_parent_chunks: If True, only process parent chunks (for parent-child mode) + + Returns: + List of created DocumentSegmentSummary instances + """ + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'", + dataset.id, + dataset.indexing_technique, + ) + return [] + + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info("Summary index is disabled for dataset %s", dataset.id) + return [] + + # Skip qa_model documents + if document.doc_form == "qa_model": + logger.info("Skipping summary generation for qa_model document %s", document.id) + return [] + + logger.info( + "Starting summary generation for document %s in dataset %s, segment_ids: %s, only_parent_chunks: %s", + document.id, + dataset.id, + len(segment_ids) if segment_ids else "all", + only_parent_chunks, + ) + + with session_factory.create_session() as session: + # Query segments (only enabled segments) + query = session.query(DocumentSegment).filter_by( + dataset_id=dataset.id, + document_id=document.id, + status="completed", + enabled=True, # Only generate summaries for enabled segments + ) + + if segment_ids: + query = query.filter(DocumentSegment.id.in_(segment_ids)) + + segments = query.all() + + if not segments: + logger.info("No segments found for document %s", document.id) + return [] + + # Batch create summary records with "not_started" status before processing + # This ensures all records exist upfront, allowing status tracking + SummaryIndexService.batch_create_summary_records( + segments=segments, + dataset=dataset, + status="not_started", + ) + session.commit() # Commit initial records + + summary_records = [] + + for segment in segments: + # For parent-child mode, only process parent chunks + # In parent-child mode, all DocumentSegments are parent chunks, + # so we process all of them. Child chunks are stored in ChildChunk table + # and are not DocumentSegments, so they won't be in the segments list. + # This check is mainly for clarity and future-proofing. + if only_parent_chunks: + # In parent-child mode, all segments in the query are parent chunks + # Child chunks are not DocumentSegments, so they won't appear here + # We can process all segments + pass + + try: + summary_record = SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, summary_index_setting + ) + summary_records.append(summary_record) + except Exception as e: + logger.exception("Failed to generate summary for segment %s", segment.id) + # Update summary record with error status + SummaryIndexService.update_summary_record_error( + segment=segment, + dataset=dataset, + error=str(e), + ) + # Continue with other segments + continue + + logger.info( + "Completed summary generation for document %s: %s summaries generated and vectorized", + document.id, + len(summary_records), + ) + return summary_records + + @staticmethod + def disable_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + disabled_by: str | None = None, + ) -> None: + """ + Disable summary records and remove vectors from vector database for segments. + Unlike delete, this preserves the summary records but marks them as disabled. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to disable summaries for. If None, disable all. + disabled_by: User ID who disabled the summaries + """ + from libs.datetime_utils import naive_utc_now + + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by( + dataset_id=dataset.id, + enabled=True, # Only disable enabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + logger.info( + "Disabling %s summary records for dataset %s, segment_ids: %s", + len(summaries), + dataset.id, + len(segment_ids) if segment_ids else "all", + ) + + # Remove from vector database (but keep records) + if dataset.indexing_technique == "high_quality": + summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] + if summary_node_ids: + try: + vector = Vector(dataset) + vector.delete_by_ids(summary_node_ids) + except Exception as e: + logger.warning("Failed to remove summary vectors: %s", str(e)) + + # Disable summary records (don't delete) + now = naive_utc_now() + for summary in summaries: + summary.enabled = False + summary.disabled_at = now + summary.disabled_by = disabled_by + session.add(summary) + + session.commit() + logger.info("Disabled %s summary records for dataset %s", len(summaries), dataset.id) + + @staticmethod + def enable_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + ) -> None: + """ + Enable summary records and re-add vectors to vector database for segments. + + Note: This method enables summaries based on chunk status, not summary_index_setting.enable. + The summary_index_setting.enable flag only controls automatic generation, + not whether existing summaries can be used. + Summary.enabled should always be kept in sync with chunk.enabled. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to enable summaries for. If None, enable all. + """ + # Only enable summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return + + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by( + dataset_id=dataset.id, + enabled=False, # Only enable disabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + logger.info( + "Enabling %s summary records for dataset %s, segment_ids: %s", + len(summaries), + dataset.id, + len(segment_ids) if segment_ids else "all", + ) + + # Re-vectorize and re-add to vector database + enabled_count = 0 + for summary in summaries: + # Get the original segment + segment = ( + session.query(DocumentSegment) + .filter_by( + id=summary.chunk_id, + dataset_id=dataset.id, + ) + .first() + ) + + # Summary.enabled stays in sync with chunk.enabled, + # only enable summary if the associated chunk is enabled. + if not segment or not segment.enabled or segment.status != "completed": + continue + + if not summary.summary_content: + continue + + try: + # Re-vectorize summary (this will update status and tokens in its own session) + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=session) + + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary) + + # Enable summary record + summary.enabled = True + summary.disabled_at = None + summary.disabled_by = None + session.add(summary) + enabled_count += 1 + except Exception: + logger.exception("Failed to re-vectorize summary %s", summary.id) + # Keep it disabled if vectorization fails + continue + + session.commit() + logger.info("Enabled %s summary records for dataset %s", enabled_count, dataset.id) + + @staticmethod + def delete_summaries_for_segments( + dataset: Dataset, + segment_ids: list[str] | None = None, + ) -> None: + """ + Delete summary records and vectors for segments (used only for actual deletion scenarios). + For disable/enable operations, use disable_summaries_for_segments/enable_summaries_for_segments. + + Args: + dataset: Dataset containing the segments + segment_ids: List of segment IDs to delete summaries for. If None, delete all. + """ + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + summaries = query.all() + + if not summaries: + return + + # Delete from vector database + if dataset.indexing_technique == "high_quality": + summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] + if summary_node_ids: + vector = Vector(dataset) + vector.delete_by_ids(summary_node_ids) + + # Delete summary records + for summary in summaries: + session.delete(summary) + + session.commit() + logger.info("Deleted %s summary records for dataset %s", len(summaries), dataset.id) + + @staticmethod + def update_summary_for_segment( + segment: DocumentSegment, + dataset: Dataset, + summary_content: str, + ) -> DocumentSegmentSummary | None: + """ + Update summary for a segment and re-vectorize it. + + Args: + segment: DocumentSegment to update summary for + dataset: Dataset containing the segment + summary_content: New summary content + + Returns: + Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality + """ + # Only update summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + return None + + # When user manually provides summary, allow saving even if summary_index_setting doesn't exist + # summary_index_setting is only needed for LLM generation, not for manual summary vectorization + # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting + + # Skip qa_model documents + if segment.document and segment.document.doc_form == "qa_model": + return None + + with session_factory.create_session() as session: + try: + # Check if summary_content is empty (whitespace-only strings are considered empty) + if not summary_content or not summary_content.strip(): + # If summary is empty, only delete existing summary vector and record + summary_record = ( + session.query(DocumentSegmentSummary) + .filter_by(chunk_id=segment.id, dataset_id=dataset.id) + .first() + ) + + if summary_record: + # Delete old vector if exists + old_summary_node_id = summary_record.summary_index_node_id + if old_summary_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Delete summary record since summary is empty + session.delete(summary_record) + session.commit() + logger.info("Deleted summary for segment %s (empty content provided)", segment.id) + return None + else: + # No existing summary record, nothing to do + logger.info("No summary record found for segment %s, nothing to delete", segment.id) + return None + + # Find existing summary record + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + + if summary_record: + # Update existing summary + old_summary_node_id = summary_record.summary_index_node_id + + # Update summary content + summary_record.summary_content = summary_content + summary_record.status = "generating" + summary_record.error = None # type: ignore[assignment] # Clear any previous errors + session.add(summary_record) + # Flush to ensure summary_content is saved before vectorize_summary queries it + session.flush() + + # Delete old vector if exists (before vectorization) + if old_summary_node_id: + try: + vector = Vector(dataset) + vector.delete_by_ids([old_summary_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Re-vectorize summary (this will update status to "completed" and tokens in its own session) + # vectorize_summary will also ensure summary_content is preserved + # Note: vectorize_summary may take time due to embedding API calls, but we need to complete it + # to ensure the summary is properly indexed + try: + # Pass the session to vectorize_summary to avoid session isolation issues + SummaryIndexService.vectorize_summary(summary_record, segment, dataset, session=session) + # Refresh the object from database to get the updated status and tokens from vectorize_summary + session.refresh(summary_record) + # Now commit the session (summary_record should have status="completed" and tokens from refresh) + session.commit() + logger.info("Successfully updated and re-vectorized summary for segment %s", segment.id) + return summary_record + except Exception as e: + # If vectorization fails, update status to error in current session + # Don't raise the exception - just log it and return the record with error status + # This allows the segment update to complete even if vectorization fails + summary_record.status = "error" + summary_record.error = f"Vectorization failed: {str(e)}" + session.commit() + logger.exception("Failed to vectorize summary for segment %s", segment.id) + # Return the record with error status instead of raising + # The caller can check the status if needed + return summary_record + else: + # Create new summary record if doesn't exist + summary_record = SummaryIndexService.create_summary_record( + segment, dataset, summary_content, status="generating" + ) + # Re-vectorize summary (this will update status to "completed" and tokens in its own session) + # Note: summary_record was created in a different session, + # so we need to merge it into current session + try: + # Merge the record into current session first (since it was created in a different session) + summary_record = session.merge(summary_record) + # Pass the session to vectorize_summary - it will update the merged record + SummaryIndexService.vectorize_summary(summary_record, segment, dataset, session=session) + # Refresh to get updated status and tokens from database + session.refresh(summary_record) + # Commit the session to persist the changes + session.commit() + logger.info("Successfully created and vectorized summary for segment %s", segment.id) + return summary_record + except Exception as e: + # If vectorization fails, update status to error in current session + # Merge the record into current session first + error_record = session.merge(summary_record) + error_record.status = "error" + error_record.error = f"Vectorization failed: {str(e)}" + session.commit() + logger.exception("Failed to vectorize summary for segment %s", segment.id) + # Return the record with error status instead of raising + return error_record + + except Exception as e: + logger.exception("Failed to update summary for segment %s", segment.id) + # Update summary record with error status if it exists + summary_record = ( + session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + ) + if summary_record: + summary_record.status = "error" + summary_record.error = str(e) + session.add(summary_record) + session.commit() + raise + + @staticmethod + def get_segment_summary(segment_id: str, dataset_id: str) -> DocumentSegmentSummary | None: + """ + Get summary for a single segment. + + Args: + segment_id: Segment ID (chunk_id) + dataset_id: Dataset ID + + Returns: + DocumentSegmentSummary instance if found, None otherwise + """ + with session_factory.create_session() as session: + return ( + session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment_id, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + .first() + ) + + @staticmethod + def get_segments_summaries(segment_ids: list[str], dataset_id: str) -> dict[str, DocumentSegmentSummary]: + """ + Get summaries for multiple segments. + + Args: + segment_ids: List of segment IDs (chunk_ids) + dataset_id: Dataset ID + + Returns: + Dictionary mapping segment_id to DocumentSegmentSummary (only enabled summaries) + """ + if not segment_ids: + return {} + + with session_factory.create_session() as session: + summary_records = ( + session.query(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id.in_(segment_ids), + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + .all() + ) + + return {summary.chunk_id: summary for summary in summary_records} + + @staticmethod + def get_document_summaries( + document_id: str, dataset_id: str, segment_ids: list[str] | None = None + ) -> list[DocumentSegmentSummary]: + """ + Get all summary records for a document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + segment_ids: Optional list of segment IDs to filter by + + Returns: + List of DocumentSegmentSummary instances (only enabled summaries) + """ + with session_factory.create_session() as session: + query = session.query(DocumentSegmentSummary).filter( + DocumentSegmentSummary.document_id == document_id, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.enabled == True, # Only return enabled summaries + ) + + if segment_ids: + query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + + return query.all() + + @staticmethod + def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None: + """ + Get summary_index_status for a single document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + tenant_id: Tenant ID + + Returns: + "SUMMARIZING" if there are pending summaries, None otherwise + """ + # Get all segments for this document (excluding qa_model and re_segment) + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment.id) + .where( + DocumentSegment.document_id == document_id, + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + .all() + ) + segment_ids = [seg.id for seg in segments] + + if not segment_ids: + return None + + # Get all summary records for these segments + summaries = SummaryIndexService.get_segments_summaries(segment_ids, dataset_id) + summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()} + + # Check if there are any "not_started" or "generating" status summaries + has_pending_summaries = any( + summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) + and summary_status_map[segment_id] in ("not_started", "generating") + for segment_id in segment_ids + ) + + return "SUMMARIZING" if has_pending_summaries else None + + @staticmethod + def get_documents_summary_index_status( + document_ids: list[str], dataset_id: str, tenant_id: str + ) -> dict[str, str | None]: + """ + Get summary_index_status for multiple documents. + + Args: + document_ids: List of document IDs + dataset_id: Dataset ID + tenant_id: Tenant ID + + Returns: + Dictionary mapping document_id to summary_index_status ("SUMMARIZING" or None) + """ + if not document_ids: + return {} + + # Get all segments for these documents (excluding qa_model and re_segment) + with session_factory.create_session() as session: + segments = ( + session.query(DocumentSegment.id, DocumentSegment.document_id) + .where( + DocumentSegment.document_id.in_(document_ids), + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + .all() + ) + + # Group segments by document_id + document_segments_map: dict[str, list[str]] = {} + for segment in segments: + doc_id = str(segment.document_id) + if doc_id not in document_segments_map: + document_segments_map[doc_id] = [] + document_segments_map[doc_id].append(segment.id) + + # Get all summary records for these segments + all_segment_ids = [seg.id for seg in segments] + summaries = SummaryIndexService.get_segments_summaries(all_segment_ids, dataset_id) + summary_status_map = {chunk_id: summary.status for chunk_id, summary in summaries.items()} + + # Calculate summary_index_status for each document + result: dict[str, str | None] = {} + for doc_id in document_ids: + segment_ids = document_segments_map.get(doc_id, []) + if not segment_ids: + # No segments, status is None (not started) + result[doc_id] = None + continue + + # Check if there are any "not_started" or "generating" status summaries + # Only check enabled=True summaries (already filtered in query) + # If segment has no summary record (summary_status_map.get returns None), + # it means the summary is disabled (enabled=False) or not created yet, ignore it + has_pending_summaries = any( + summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) + and summary_status_map[segment_id] in ("not_started", "generating") + for segment_id in segment_ids + ) + + if has_pending_summaries: + # Task is still running (not started or generating) + result[doc_id] = "SUMMARIZING" + else: + # All enabled=True summaries are "completed" or "error", task finished + # Or no enabled=True summaries exist (all disabled) + result[doc_id] = None + + return result + + @staticmethod + def get_document_summary_status_detail( + document_id: str, + dataset_id: str, + ) -> dict[str, Any]: + """ + Get detailed summary status for a document. + + Args: + document_id: Document ID + dataset_id: Dataset ID + + Returns: + Dictionary containing: + - total_segments: Total number of segments in the document + - summary_status: Dictionary with status counts + - completed: Number of summaries completed + - generating: Number of summaries being generated + - error: Number of summaries with errors + - not_started: Number of segments without summary records + - summaries: List of summary records with status and content preview + """ + from services.dataset_service import SegmentService + + # Get all segments for this document + segments = SegmentService.get_segments_by_document_and_dataset( + document_id=document_id, + dataset_id=dataset_id, + status="completed", + enabled=True, + ) + + total_segments = len(segments) + + # Get all summary records for these segments + segment_ids = [segment.id for segment in segments] + summaries = [] + if segment_ids: + summaries = SummaryIndexService.get_document_summaries( + document_id=document_id, + dataset_id=dataset_id, + segment_ids=segment_ids, + ) + + # Create a mapping of chunk_id to summary + summary_map = {summary.chunk_id: summary for summary in summaries} + + # Count statuses + status_counts = { + "completed": 0, + "generating": 0, + "error": 0, + "not_started": 0, + } + + summary_list = [] + for segment in segments: + summary = summary_map.get(segment.id) + if summary: + status = summary.status + status_counts[status] = status_counts.get(status, 0) + 1 + summary_list.append( + { + "segment_id": segment.id, + "segment_position": segment.position, + "status": summary.status, + "summary_preview": ( + summary.summary_content[:100] + "..." + if summary.summary_content and len(summary.summary_content) > 100 + else summary.summary_content + ), + "error": summary.error, + "created_at": int(summary.created_at.timestamp()) if summary.created_at else None, + "updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None, + } + ) + else: + status_counts["not_started"] += 1 + summary_list.append( + { + "segment_id": segment.id, + "segment_position": segment.position, + "status": "not_started", + "summary_preview": None, + "error": None, + "created_at": None, + "updated_at": None, + } + ) + + return { + "total_segments": total_segments, + "summary_status": status_counts, + "summaries": summary_list, + } diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 62e6497e9d..2d3d00cd50 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -118,6 +118,19 @@ def add_document_to_index_task(dataset_document_id: str): ) session.commit() + # Enable summary indexes for all segments in this document + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + if segment_ids_list: + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + ) + except Exception as e: + logger.warning("Failed to enable summaries for document %s: %s", dataset_document.id, str(e)) + end_at = time.perf_counter() logger.info( click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green") diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 74b939e84d..d388284980 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -50,7 +50,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 86e7cc7160..91ace6be02 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -51,7 +51,9 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index bcca1bf49f..4214f043e0 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -42,7 +42,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): ).all() index_node_ids = [segment.index_node_id for segment in segments] - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) segment_ids = [segment.id for segment in segments] segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) session.execute(segment_delete_stmt) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index bfa709502c..764c635d83 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -47,6 +47,7 @@ def delete_segment_from_index_task( doc_form = dataset_document.doc_form # Proceed with index cleanup using the index_node_ids directly + # For actual deletion, we should delete summaries (not just disable them) index_processor = IndexProcessorFactory(doc_form).init_index_processor() index_processor.clean( dataset, @@ -54,6 +55,7 @@ def delete_segment_from_index_task( with_keywords=True, delete_child_chunks=True, precomputed_child_node_ids=child_node_ids, + delete_summaries=True, # Actually delete summaries when segment is deleted ) if dataset.is_multimodal: # delete segment attachment binding diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 0ce6429a94..bc45171623 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -60,6 +60,18 @@ def disable_segment_from_index_task(segment_id: str): index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.clean(dataset, [segment.index_node_id]) + # Disable summary index for this segment + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=[segment.id], + disabled_by=segment.disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e)) + end_at = time.perf_counter() logger.info( click.style( diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 03635902d1..3cc267e821 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -68,6 +68,21 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen index_node_ids.extend(attachment_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + # Disable summary indexes for these segments + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + try: + # Get disabled_by from first segment (they should all have the same disabled_by) + disabled_by = segments[0].disabled_by if segments else None + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + disabled_by=disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summaries for segments: %s", str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) except Exception: diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 3bdff60196..34496e9c6f 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -14,6 +14,7 @@ from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document from services.feature_service import FeatureService +from tasks.generate_summary_index_task import generate_summary_index_task logger = logging.getLogger(__name__) @@ -99,6 +100,78 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): indexing_runner.run(documents) end_at = time.perf_counter() logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + + # Trigger summary index generation for completed documents if enabled + # Only generate for high_quality indexing technique and when summary_index_setting is enabled + # Re-query dataset to get latest summary_index_setting (in case it was updated) + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.warning("Dataset %s not found after indexing", dataset_id) + return + + if dataset.indexing_technique == "high_quality": + summary_index_setting = dataset.summary_index_setting + if summary_index_setting and summary_index_setting.get("enable"): + # expire all session to get latest document's indexing status + session.expire_all() + # Check each document's indexing status and trigger summary generation if completed + for document_id in document_ids: + # Re-query document to get latest status (IndexingRunner may have updated it) + document = ( + session.query(Document) + .where(Document.id == document_id, Document.dataset_id == dataset_id) + .first() + ) + if document: + logger.info( + "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s", + document_id, + document.indexing_status, + document.doc_form, + document.need_summary, + ) + if ( + document.indexing_status == "completed" + and document.doc_form != "qa_model" + and document.need_summary is True + ): + try: + generate_summary_index_task.delay(dataset.id, document_id, None) + logger.info( + "Queued summary index generation task for document %s in dataset %s " + "after indexing completed", + document_id, + dataset.id, + ) + except Exception: + logger.exception( + "Failed to queue summary index generation task for document %s", + document_id, + ) + # Don't fail the entire indexing process if summary task queuing fails + else: + logger.info( + "Skipping summary generation for document %s: " + "status=%s, doc_form=%s, need_summary=%s", + document_id, + document.indexing_status, + document.doc_form, + document.need_summary, + ) + else: + logger.warning("Document %s not found after indexing", document_id) + else: + logger.info( + "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s", + dataset.id, + summary_index_setting.get("enable") if summary_index_setting else None, + ) + else: + logger.info( + "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')", + dataset.id, + dataset.indexing_technique, + ) except DocumentIsPausedError as ex: logger.info(click.style(str(ex), fg="yellow")) except Exception: diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 1f9f21aa7e..41ebb0b076 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -106,6 +106,17 @@ def enable_segment_to_index_task(segment_id: str): # save vector index index_processor.load(dataset, [document], multimodal_documents=multimodel_documents) + # Enable summary index for this segment + from services.summary_index_service import SummaryIndexService + + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=[segment.id], + ) + except Exception as e: + logger.warning("Failed to enable summary for segment %s: %s", segment.id, str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) except Exception as e: diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index 48d3c8e178..d90eb4c39f 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -106,6 +106,18 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i # save vector index index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) + # Enable summary indexes for these segments + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + try: + SummaryIndexService.enable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + ) + except Exception as e: + logger.warning("Failed to enable summaries for segments: %s", str(e)) + end_at = time.perf_counter() logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) except Exception as e: diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py new file mode 100644 index 0000000000..e4273e16b5 --- /dev/null +++ b/api/tasks/generate_summary_index_task.py @@ -0,0 +1,119 @@ +"""Async task for generating summary indexes.""" + +import logging +import time + +import click +from celery import shared_task + +from core.db.session_factory import session_factory +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) + + +@shared_task(queue="dataset") +def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None): + """ + Async generate summary index for document segments. + + Args: + dataset_id: Dataset ID + document_id: Document ID + segment_ids: Optional list of specific segment IDs to process. If None, process all segments. + + Usage: + generate_summary_index_task.delay(dataset_id, document_id) + generate_summary_index_task.delay(dataset_id, document_id, segment_ids) + """ + logger.info( + click.style( + f"Start generating summary index for document {document_id} in dataset {dataset_id}", + fg="green", + ) + ) + start_at = time.perf_counter() + + try: + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + + document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + if not document: + logger.error(click.style(f"Document not found: {document_id}", fg="red")) + return + + # Check if document needs summary + if not document.need_summary: + logger.info( + click.style( + f"Skipping summary generation for document {document_id}: need_summary is False", + fg="cyan", + ) + ) + return + + # Only generate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + click.style( + f"Skipping summary generation for dataset {dataset_id}: " + f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'", + fg="cyan", + ) + ) + return + + # Check if summary index is enabled + summary_index_setting = dataset.summary_index_setting + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info( + click.style( + f"Summary index is disabled for dataset {dataset_id}", + fg="cyan", + ) + ) + return + + # Determine if only parent chunks should be processed + only_parent_chunks = dataset.chunk_structure == "parent_child_index" + + # Generate summaries + summary_records = SummaryIndexService.generate_summaries_for_document( + dataset=dataset, + document=document, + summary_index_setting=summary_index_setting, + segment_ids=segment_ids, + only_parent_chunks=only_parent_chunks, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Summary index generation completed for document {document_id}: " + f"{len(summary_records)} summaries generated, latency: {end_at - start_at}", + fg="green", + ) + ) + + except Exception as e: + logger.exception("Failed to generate summary index for document %s", document_id) + # Update document segments with error status if needed + if segment_ids: + error_message = f"Summary generation failed: {str(e)}" + with session_factory.create_session() as session: + session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + ).update( + { + DocumentSegment.error: error_message, + }, + synchronize_session=False, + ) + session.commit() diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py new file mode 100644 index 0000000000..cf8988d13e --- /dev/null +++ b/api/tasks/regenerate_summary_index_task.py @@ -0,0 +1,315 @@ +"""Task for regenerating summary indexes when dataset settings change.""" + +import logging +import time +from collections import defaultdict + +import click +from celery import shared_task +from sqlalchemy import or_, select + +from core.db.session_factory import session_factory +from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary +from models.dataset import Document as DatasetDocument +from services.summary_index_service import SummaryIndexService + +logger = logging.getLogger(__name__) + + +@shared_task(queue="dataset") +def regenerate_summary_index_task( + dataset_id: str, + regenerate_reason: str = "summary_model_changed", + regenerate_vectors_only: bool = False, +): + """ + Regenerate summary indexes for all documents in a dataset. + + This task is triggered when: + 1. summary_index_setting model changes (regenerate_reason="summary_model_changed") + - Regenerates summary content and vectors for all existing summaries + 2. embedding_model changes (regenerate_reason="embedding_model_changed") + - Only regenerates vectors for existing summaries (keeps summary content) + + Args: + dataset_id: Dataset ID + regenerate_reason: Reason for regeneration ("summary_model_changed" or "embedding_model_changed") + regenerate_vectors_only: If True, only regenerate vectors without regenerating summary content + """ + logger.info( + click.style( + f"Start regenerate summary index for dataset {dataset_id}, reason: {regenerate_reason}", + fg="green", + ) + ) + start_at = time.perf_counter() + + try: + with session_factory.create_session() as session: + dataset = session.query(Dataset).filter_by(id=dataset_id).first() + if not dataset: + logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + + # Only regenerate summary index for high_quality indexing technique + if dataset.indexing_technique != "high_quality": + logger.info( + click.style( + f"Skipping summary regeneration for dataset {dataset_id}: " + f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'", + fg="cyan", + ) + ) + return + + # Check if summary index is enabled (only for summary_model change) + # For embedding_model change, we still re-vectorize existing summaries even if setting is disabled + summary_index_setting = dataset.summary_index_setting + if not regenerate_vectors_only: + # For summary_model change, require summary_index_setting to be enabled + if not summary_index_setting or not summary_index_setting.get("enable"): + logger.info( + click.style( + f"Summary index is disabled for dataset {dataset_id}", + fg="cyan", + ) + ) + return + + total_segments_processed = 0 + total_segments_failed = 0 + + if regenerate_vectors_only: + # For embedding_model change: directly query all segments with existing summaries + # Don't require document indexing_status == "completed" + # Include summaries with status "completed" or "error" (if they have content) + segments_with_summaries = ( + session.query(DocumentSegment, DocumentSegmentSummary) + .join( + DocumentSegmentSummary, + DocumentSegment.id == DocumentSegmentSummary.chunk_id, + ) + .join( + DatasetDocument, + DocumentSegment.document_id == DatasetDocument.id, + ) + .where( + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.status == "completed", # Segment must be completed + DocumentSegment.enabled == True, + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.summary_content.isnot(None), # Must have summary content + # Include completed summaries or error summaries (with content) + or_( + DocumentSegmentSummary.status == "completed", + DocumentSegmentSummary.status == "error", + ), + DatasetDocument.enabled == True, # Document must be enabled + DatasetDocument.archived == False, # Document must not be archived + DatasetDocument.doc_form != "qa_model", # Skip qa_model documents + ) + .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) + .all() + ) + + if not segments_with_summaries: + logger.info( + click.style( + f"No segments with summaries found for re-vectorization in dataset {dataset_id}", + fg="cyan", + ) + ) + return + + logger.info( + "Found %s segments with summaries for re-vectorization in dataset %s", + len(segments_with_summaries), + dataset_id, + ) + + # Group by document for logging + segments_by_document = defaultdict(list) + for segment, summary_record in segments_with_summaries: + segments_by_document[segment.document_id].append((segment, summary_record)) + + logger.info( + "Segments grouped into %s documents for re-vectorization", + len(segments_by_document), + ) + + for document_id, segment_summary_pairs in segments_by_document.items(): + logger.info( + "Re-vectorizing summaries for %s segments in document %s", + len(segment_summary_pairs), + document_id, + ) + + for segment, summary_record in segment_summary_pairs: + try: + # Delete old vector + if summary_record.summary_index_node_id: + try: + from core.rag.datasource.vdb.vector_factory import Vector + + vector = Vector(dataset) + vector.delete_by_ids([summary_record.summary_index_node_id]) + except Exception as e: + logger.warning( + "Failed to delete old summary vector for segment %s: %s", + segment.id, + str(e), + ) + + # Re-vectorize with new embedding model + SummaryIndexService.vectorize_summary(summary_record, segment, dataset) + session.commit() + total_segments_processed += 1 + + except Exception as e: + logger.error( + "Failed to re-vectorize summary for segment %s: %s", + segment.id, + str(e), + exc_info=True, + ) + total_segments_failed += 1 + # Update summary record with error status + summary_record.status = "error" + summary_record.error = f"Re-vectorization failed: {str(e)}" + session.add(summary_record) + session.commit() + continue + + else: + # For summary_model change: require document indexing_status == "completed" + # Get all documents with completed indexing status + dataset_documents = session.scalars( + select(DatasetDocument).where( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + ).all() + + if not dataset_documents: + logger.info( + click.style( + f"No documents found for summary regeneration in dataset {dataset_id}", + fg="cyan", + ) + ) + return + + logger.info( + "Found %s documents for summary regeneration in dataset %s", + len(dataset_documents), + dataset_id, + ) + + for dataset_document in dataset_documents: + # Skip qa_model documents + if dataset_document.doc_form == "qa_model": + continue + + try: + # Get all segments with existing summaries + segments = ( + session.query(DocumentSegment) + .join( + DocumentSegmentSummary, + DocumentSegment.id == DocumentSegmentSummary.chunk_id, + ) + .where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegmentSummary.dataset_id == dataset_id, + ) + .order_by(DocumentSegment.position.asc()) + .all() + ) + + if not segments: + continue + + logger.info( + "Regenerating summaries for %s segments in document %s", + len(segments), + dataset_document.id, + ) + + for segment in segments: + summary_record = None + try: + # Get existing summary record + summary_record = ( + session.query(DocumentSegmentSummary) + .filter_by( + chunk_id=segment.id, + dataset_id=dataset_id, + ) + .first() + ) + + if not summary_record: + logger.warning("Summary record not found for segment %s, skipping", segment.id) + continue + + # Regenerate both summary content and vectors (for summary_model change) + SummaryIndexService.generate_and_vectorize_summary( + segment, dataset, summary_index_setting + ) + session.commit() + total_segments_processed += 1 + + except Exception as e: + logger.error( + "Failed to regenerate summary for segment %s: %s", + segment.id, + str(e), + exc_info=True, + ) + total_segments_failed += 1 + # Update summary record with error status + if summary_record: + summary_record.status = "error" + summary_record.error = f"Regeneration failed: {str(e)}" + session.add(summary_record) + session.commit() + continue + + except Exception as e: + logger.error( + "Failed to process document %s for summary regeneration: %s", + dataset_document.id, + str(e), + exc_info=True, + ) + continue + + end_at = time.perf_counter() + if regenerate_vectors_only: + logger.info( + click.style( + f"Summary re-vectorization completed for dataset {dataset_id}: " + f"{total_segments_processed} segments processed successfully, " + f"{total_segments_failed} segments failed, " + f"latency: {end_at - start_at:.2f}s", + fg="green", + ) + ) + else: + logger.info( + click.style( + f"Summary index regeneration completed for dataset {dataset_id}: " + f"{total_segments_processed} segments processed successfully, " + f"{total_segments_failed} segments failed, " + f"latency: {end_at - start_at:.2f}s", + fg="green", + ) + ) + + except Exception: + logger.exception("Regenerate summary index failed for dataset %s", dataset_id) diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index c3c255fb17..55259ab527 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -46,6 +46,21 @@ def remove_document_from_index_task(document_id: str): index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() + + # Disable summary indexes for all segments in this document + from services.summary_index_service import SummaryIndexService + + segment_ids_list = [segment.id for segment in segments] + if segment_ids_list: + try: + SummaryIndexService.disable_summaries_for_segments( + dataset=dataset, + segment_ids=segment_ids_list, + disabled_by=document.disabled_by, + ) + except Exception as e: + logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e)) + index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: try: diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index fe0e03f7b8..a2bf10001a 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -1,3 +1,5 @@ +import uuid + from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector from core.rag.models.document import Document from tests.integration_tests.vdb.test_vector_store import ( @@ -18,6 +20,10 @@ class QdrantVectorTest(AbstractVectorTest): api_key="difyai123456", ), ) + # Additional doc IDs for multi-keyword search tests + self.doc_apple_id = "" + self.doc_banana_id = "" + self.doc_both_id = "" def search_by_vector(self): super().search_by_vector() @@ -27,6 +33,77 @@ class QdrantVectorTest(AbstractVectorTest): ) assert len(hits_by_vector) == 0 + def _create_document(self, content: str, doc_id: str) -> Document: + """Create a document with the given content and doc_id.""" + return Document( + page_content=content, + metadata={ + "doc_id": doc_id, + "doc_hash": doc_id, + "document_id": doc_id, + "dataset_id": self.dataset_id, + }, + ) + + def setup_multi_keyword_documents(self): + """Create test documents with different keyword combinations for multi-keyword search tests.""" + self.doc_apple_id = str(uuid.uuid4()) + self.doc_banana_id = str(uuid.uuid4()) + self.doc_both_id = str(uuid.uuid4()) + + documents = [ + self._create_document("This document contains apple only", self.doc_apple_id), + self._create_document("This document contains banana only", self.doc_banana_id), + self._create_document("This document contains both apple and banana", self.doc_both_id), + ] + embeddings = [self.example_embedding] * len(documents) + + self.vector.add_texts(documents=documents, embeddings=embeddings) + + def search_by_full_text_multi_keyword(self): + """Test multi-keyword search returns docs matching ANY keyword (OR logic).""" + # First verify single keyword searches work correctly + hits_apple = self.vector.search_by_full_text(query="apple", top_k=10) + apple_ids = {doc.metadata["doc_id"] for doc in hits_apple} + assert self.doc_apple_id in apple_ids, "Document with 'apple' should be found" + assert self.doc_both_id in apple_ids, "Document with 'apple and banana' should be found" + + hits_banana = self.vector.search_by_full_text(query="banana", top_k=10) + banana_ids = {doc.metadata["doc_id"] for doc in hits_banana} + assert self.doc_banana_id in banana_ids, "Document with 'banana' should be found" + assert self.doc_both_id in banana_ids, "Document with 'apple and banana' should be found" + + # Test multi-keyword search returns all matching documents + hits = self.vector.search_by_full_text(query="apple banana", top_k=10) + doc_ids = {doc.metadata["doc_id"] for doc in hits} + + assert self.doc_apple_id in doc_ids, "Document with 'apple' should be found in multi-keyword search" + assert self.doc_banana_id in doc_ids, "Document with 'banana' should be found in multi-keyword search" + assert self.doc_both_id in doc_ids, "Document with both keywords should be found" + # Expect 3 results: doc_apple (apple only), doc_banana (banana only), doc_both (contains both) + assert len(hits) == 3, f"Expected 3 documents, got {len(hits)}" + + # Test keyword order independence + hits_ba = self.vector.search_by_full_text(query="banana apple", top_k=10) + ids_ba = {doc.metadata["doc_id"] for doc in hits_ba} + assert doc_ids == ids_ba, "Keyword order should not affect search results" + + # Test no duplicates in results + doc_id_list = [doc.metadata["doc_id"] for doc in hits] + assert len(doc_id_list) == len(set(doc_id_list)), "Search results should not contain duplicates" + + def run_all_tests(self): + self.create_vector() + self.search_by_vector() + self.search_by_full_text() + self.text_exists() + self.get_ids_by_metadata_field() + # Multi-keyword search tests + self.setup_multi_keyword_documents() + self.search_by_full_text_multi_keyword() + # Cleanup - delete_vector() removes the entire collection + self.delete_vector() + def test_qdrant_vector(setup_mock_redis): QdrantVectorTest().run_all_tests() diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index f9e59a5f05..0792ada194 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,7 +1,9 @@ """Primarily used for testing merged cell scenarios""" +import io import os import tempfile +from pathlib import Path from types import SimpleNamespace from docx import Document @@ -56,6 +58,42 @@ def test_parse_row(): assert extractor._parse_row(row, {}, 3) == gt[idx] +def test_init_downloads_via_ssrf_proxy(monkeypatch): + doc = Document() + doc.add_paragraph("hello") + buf = io.BytesIO() + doc.save(buf) + docx_bytes = buf.getvalue() + + calls: list[tuple[str, object]] = [] + + class FakeResponse: + status_code = 200 + content = docx_bytes + + def close(self) -> None: + calls.append(("close", None)) + + def fake_get(url: str, **kwargs): + calls.append(("get", (url, kwargs))) + return FakeResponse() + + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get)) + + extractor = WordExtractor("https://example.com/test.docx", "tenant_id", "user_id") + try: + assert calls + assert calls[0][0] == "get" + url, kwargs = calls[0][1] + assert url == "https://example.com/test.docx" + assert kwargs.get("timeout") is None + assert extractor.web_path == "https://example.com/test.docx" + assert extractor.file_path != extractor.web_path + assert Path(extractor.file_path).read_bytes() == docx_bytes + finally: + extractor.temp_file.close() + + def test_extract_images_from_docx(monkeypatch): external_bytes = b"ext-bytes" internal_bytes = b"int-bytes" diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py index 0aabe2fc30..08818945e3 100644 --- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -138,6 +138,7 @@ class TestDatasetServiceUpdateDataset: "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, + patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task, patch( "services.dataset_service.current_user", create_autospec(Account, instance=True) ) as mock_current_user, @@ -147,6 +148,7 @@ class TestDatasetServiceUpdateDataset: "model_manager": mock_model_manager, "get_binding": mock_get_binding, "task": mock_task, + "regenerate_task": mock_regenerate_task, "current_user": mock_current_user, } @@ -549,6 +551,13 @@ class TestDatasetServiceUpdateDataset: # Verify vector index task was triggered mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update") + # Verify regenerate summary index task was triggered (when embedding_model changes) + mock_internal_provider_dependencies["regenerate_task"].delay.assert_called_once_with( + "dataset-123", + regenerate_reason="embedding_model_changed", + regenerate_vectors_only=True, + ) + # Verify return value assert result == dataset diff --git a/api/ty.toml b/api/ty.toml index bb4ff5bbcf..640ed6cdee 100644 --- a/api/ty.toml +++ b/api/ty.toml @@ -1,11 +1,33 @@ [src] exclude = [ - # TODO: enable when violations fixed + # deps groups (A1/A2/B/C/D/E) + # A1: foundational runtime typing / provider plumbing + "core/mcp/session", + "core/model_runtime/model_providers", + "core/workflow/nodes/protocols.py", + "libs/gmpy2_pkcs10aep_cipher.py", + # A2: workflow engine/nodes + "core/workflow", + "core/app/workflow", + "core/helper/code_executor", + # B: app runner + prompt + "core/prompt", + "core/app/apps/base_app_runner.py", "core/app/apps/workflow_app_runner.py", + # C: services/controllers/fields/libs + "services", "controllers/console/app", "controllers/console/explore", "controllers/console/datasets", "controllers/console/workspace", + "controllers/service_api/wraps.py", + "fields/conversation_fields.py", + "libs/external_api.py", + # D: observability + integrations + "core/ops", + "extensions", + # E: vector DB integrations + "core/rag/datasource/vdb", # non-producition or generated code "migrations", "tests", diff --git a/docker/.env.example b/docker/.env.example index c7246ae11f..41a0205bf5 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -397,7 +397,7 @@ WEB_API_CORS_ALLOW_ORIGINS=* # Specifies the allowed origins for cross-origin requests to the console API, # e.g. https://cloud.dify.ai or * for all origins. CONSOLE_CORS_ALLOW_ORIGINS=* -# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the siteโ€™s top-level domain (e.g., `example.com`). Leading dots are optional. +# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site's top-level domain (e.g., `example.com`). Leading dots are optional. COOKIE_DOMAIN= # When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. NEXT_PUBLIC_COOKIE_DOMAIN= @@ -1080,7 +1080,7 @@ ALIYUN_SLS_ENDPOINT= ALIYUN_SLS_REGION= # Aliyun SLS Project Name ALIYUN_SLS_PROJECT_NAME= -# Number of days to retain workflow run logs (default: 365 days๏ผŒ 3650 for permanent storage) +# Number of days to retain workflow run logs (default: 365 days, 3650 for permanent storage) ALIYUN_SLS_LOGSTORE_TTL=365 # Enable dual-write to both SLS LogStore and SQL database (default: false) LOGSTORE_DUAL_WRITE_ENABLED=false @@ -1375,6 +1375,7 @@ PLUGIN_DAEMON_PORT=5002 PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi PLUGIN_DAEMON_URL=http://plugin_daemon:5002 PLUGIN_MAX_PACKAGE_SIZE=52428800 +PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600 PLUGIN_PPROF_ENABLED=false PLUGIN_DEBUGGING_HOST=0.0.0.0 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 902ca3103c..2e97891a60 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -589,6 +589,7 @@ x-shared-env: &shared-api-worker-env PLUGIN_DAEMON_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} PLUGIN_DAEMON_URL: ${PLUGIN_DAEMON_URL:-http://plugin_daemon:5002} PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} + PLUGIN_MODEL_SCHEMA_CACHE_TTL: ${PLUGIN_MODEL_SCHEMA_CACHE_TTL:-3600} PLUGIN_PPROF_ENABLED: ${PLUGIN_PPROF_ENABLED:-false} PLUGIN_DEBUGGING_HOST: ${PLUGIN_DEBUGGING_HOST:-0.0.0.0} PLUGIN_DEBUGGING_PORT: ${PLUGIN_DEBUGGING_PORT:-5003} diff --git a/docker/generate_docker_compose b/docker/generate_docker_compose index b5c0acefb1..bf6c1423c9 100755 --- a/docker/generate_docker_compose +++ b/docker/generate_docker_compose @@ -9,7 +9,7 @@ def parse_env_example(file_path): Parses the .env.example file and returns a dictionary with variable names as keys and default values as values. """ env_vars = {} - with open(file_path, "r") as f: + with open(file_path, "r", encoding="utf-8") as f: for line_number, line in enumerate(f, 1): line = line.strip() # Ignore empty lines and comments @@ -55,7 +55,7 @@ def insert_shared_env(template_path, output_path, shared_env_block, header_comme Inserts the shared environment variables block and header comments into the template file, removing any existing x-shared-env anchors, and generates the final docker-compose.yaml file. """ - with open(template_path, "r") as f: + with open(template_path, "r", encoding="utf-8") as f: template_content = f.read() # Remove existing x-shared-env: &shared-api-worker-env lines @@ -69,7 +69,7 @@ def insert_shared_env(template_path, output_path, shared_env_block, header_comme # Prepare the final content with header comments and shared env block final_content = f"{header_comments}\n{shared_env_block}\n\n{template_content}" - with open(output_path, "w") as f: + with open(output_path, "w", encoding="utf-8") as f: f.write(final_content) print(f"Generated {output_path}") diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 255feaccdf..aa31f0201f 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -31,6 +31,7 @@ import { fetchWorkflowDraft } from '@/service/workflow' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import AppIcon from '../base/app-icon' import AppOperations from './app-operations' @@ -145,13 +146,8 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx appID: appDetail.id, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${appDetail.name}.yml` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${appDetail.name}.yml` }) } catch { notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index 4d7c832e04..96127c4210 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -11,6 +11,7 @@ import { datasetDetailQueryKeyPrefix, useInvalidDatasetList } from '@/service/kn import { useInvalid } from '@/service/use-base' import { useExportPipelineDSL } from '@/service/use-pipeline' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import ActionButton from '../../base/action-button' import Confirm from '../../base/confirm' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' @@ -64,13 +65,8 @@ const DropDown = ({ pipelineId: pipeline_id, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${name}.pipeline` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${name}.pipeline` }) } catch { Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx index 5add1aed32..4fc1e26007 100644 --- a/web/app/components/app/annotation/header-opts/index.tsx +++ b/web/app/components/app/annotation/header-opts/index.tsx @@ -21,6 +21,7 @@ import { LanguagesSupported } from '@/i18n-config/language' import { clearAllAnnotations, fetchExportAnnotationList } from '@/service/annotation' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import Button from '../../../base/button' import AddAnnotationModal from '../add-annotation-modal' import BatchAddModal from '../batch-add-annotation-modal' @@ -56,28 +57,23 @@ const HeaderOptions: FC = ({ ) const JSONLOutput = () => { - const a = document.createElement('a') const content = listTransformer(list).join('\n') const file = new Blob([content], { type: 'application/jsonl' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `annotations-${locale}.jsonl` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `annotations-${locale}.jsonl` }) } - const fetchList = async () => { + const fetchList = React.useCallback(async () => { const { data }: any = await fetchExportAnnotationList(appId) setList(data as AnnotationItemBasic[]) - } + }, [appId]) useEffect(() => { fetchList() - }, []) + }, [fetchList]) useEffect(() => { if (controlUpdateList) fetchList() - }, [controlUpdateList]) + }, [controlUpdateList, fetchList]) const [showBulkImportModal, setShowBulkImportModal] = useState(false) const [showClearConfirm, setShowClearConfirm] = useState(false) diff --git a/web/app/components/app/configuration/config-var/index.spec.tsx b/web/app/components/app/configuration/config-var/index.spec.tsx index b5015ed079..490d7b4410 100644 --- a/web/app/components/app/configuration/config-var/index.spec.tsx +++ b/web/app/components/app/configuration/config-var/index.spec.tsx @@ -2,7 +2,7 @@ import type { ReactNode } from 'react' import type { IConfigVarProps } from './index' import type { ExternalDataTool } from '@/models/common' import type { PromptVariable } from '@/models/debug' -import { act, fireEvent, render, screen } from '@testing-library/react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { vi } from 'vitest' import Toast from '@/app/components/base/toast' @@ -240,7 +240,9 @@ describe('ConfigVar', () => { const saveButton = await screen.findByRole('button', { name: 'common.operation.save' }) fireEvent.click(saveButton) - expect(onPromptVariablesChange).toHaveBeenCalledTimes(1) + await waitFor(() => { + expect(onPromptVariablesChange).toHaveBeenCalledTimes(1) + }) }) it('should show error when variable key is duplicated', async () => { diff --git a/web/app/components/app/configuration/config/automatic/automatic-btn.spec.tsx b/web/app/components/app/configuration/config/automatic/automatic-btn.spec.tsx new file mode 100644 index 0000000000..f027f643a7 --- /dev/null +++ b/web/app/components/app/configuration/config/automatic/automatic-btn.spec.tsx @@ -0,0 +1,77 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import AutomaticBtn from './automatic-btn' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('AutomaticBtn', () => { + const mockOnClick = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render the button with correct text', () => { + render() + + expect(screen.getByText('operation.automatic')).toBeInTheDocument() + }) + + it('should render the sparkling icon', () => { + const { container } = render() + + // The icon should be an SVG element inside the button + const svg = container.querySelector('svg') + expect(svg).toBeTruthy() + }) + + it('should render as a button element', () => { + render() + + expect(screen.getByRole('button')).toBeInTheDocument() + }) + }) + + describe('User Interactions', () => { + it('should call onClick when button is clicked', () => { + render() + + const button = screen.getByRole('button') + fireEvent.click(button) + + expect(mockOnClick).toHaveBeenCalledTimes(1) + }) + + it('should call onClick multiple times on multiple clicks', () => { + render() + + const button = screen.getByRole('button') + + fireEvent.click(button) + fireEvent.click(button) + fireEvent.click(button) + + expect(mockOnClick).toHaveBeenCalledTimes(3) + }) + }) + + describe('Styling', () => { + it('should have secondary-accent variant', () => { + render() + + const button = screen.getByRole('button') + expect(button.className).toContain('secondary-accent') + }) + + it('should have small size', () => { + render() + + const button = screen.getByRole('button') + expect(button.className).toContain('small') + }) + }) +}) diff --git a/web/app/components/app/log/empty-element.spec.tsx b/web/app/components/app/log/empty-element.spec.tsx new file mode 100644 index 0000000000..71d2bd0dd2 --- /dev/null +++ b/web/app/components/app/log/empty-element.spec.tsx @@ -0,0 +1,134 @@ +import type { App } from '@/types/app' +import { render, screen } from '@testing-library/react' +import { AppModeEnum } from '@/types/app' +import EmptyElement from './empty-element' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), + Trans: ({ i18nKey, components }: { i18nKey: string, components: Record }) => ( + + {i18nKey} + {components.shareLink} + {components.testLink} + + ), +})) + +vi.mock('@/utils/app-redirection', () => ({ + getRedirectionPath: (isTest: boolean, _app: App) => isTest ? '/test-path' : '/prod-path', +})) + +vi.mock('@/utils/var', () => ({ + basePath: '/base', +})) + +describe('EmptyElement', () => { + const createMockAppDetail = (mode: AppModeEnum) => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test description', + mode, + icon_type: 'emoji', + icon: 'test-icon', + icon_background: '#ffffff', + enable_site: true, + enable_api: true, + created_at: Date.now(), + site: { + access_token: 'test-token', + app_base_url: 'https://app.example.com', + }, + }) as unknown as App + + describe('Rendering', () => { + it('should render empty element with title', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + render() + + expect(screen.getByText('table.empty.element.title')).toBeInTheDocument() + }) + + it('should render Trans component with i18n key', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + render() + + const transComponent = screen.getByTestId('trans-component') + expect(transComponent).toHaveAttribute('data-i18n-key', 'table.empty.element.content') + }) + + it('should render ThreeDotsIcon SVG', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + const { container } = render() + + const svg = container.querySelector('svg') + expect(svg).toBeInTheDocument() + }) + }) + + describe('App Mode Handling', () => { + it('should use CHAT mode for chat apps', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + render() + + const link = screen.getAllByRole('link')[0] + expect(link).toHaveAttribute('href', 'https://app.example.com/base/chat/test-token') + }) + + it('should use COMPLETION mode for completion apps', () => { + const appDetail = createMockAppDetail(AppModeEnum.COMPLETION) + render() + + const link = screen.getAllByRole('link')[0] + expect(link).toHaveAttribute('href', 'https://app.example.com/base/completion/test-token') + }) + + it('should use WORKFLOW mode for workflow apps', () => { + const appDetail = createMockAppDetail(AppModeEnum.WORKFLOW) + render() + + const link = screen.getAllByRole('link')[0] + expect(link).toHaveAttribute('href', 'https://app.example.com/base/workflow/test-token') + }) + + it('should use CHAT mode for advanced-chat apps', () => { + const appDetail = createMockAppDetail(AppModeEnum.ADVANCED_CHAT) + render() + + const link = screen.getAllByRole('link')[0] + expect(link).toHaveAttribute('href', 'https://app.example.com/base/chat/test-token') + }) + + it('should use CHAT mode for agent-chat apps', () => { + const appDetail = createMockAppDetail(AppModeEnum.AGENT_CHAT) + render() + + const link = screen.getAllByRole('link')[0] + expect(link).toHaveAttribute('href', 'https://app.example.com/base/chat/test-token') + }) + }) + + describe('Links', () => { + it('should render share link with correct attributes', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + render() + + const links = screen.getAllByRole('link') + const shareLink = links[0] + + expect(shareLink).toHaveAttribute('target', '_blank') + expect(shareLink).toHaveAttribute('rel', 'noopener noreferrer') + }) + + it('should render test link with redirection path', () => { + const appDetail = createMockAppDetail(AppModeEnum.CHAT) + render() + + const links = screen.getAllByRole('link') + const testLink = links[1] + + expect(testLink).toHaveAttribute('href', '/test-path') + }) + }) +}) diff --git a/web/app/components/app/log/filter.spec.tsx b/web/app/components/app/log/filter.spec.tsx new file mode 100644 index 0000000000..8e978cdf9e --- /dev/null +++ b/web/app/components/app/log/filter.spec.tsx @@ -0,0 +1,210 @@ +import type { QueryParam } from './index' +import { fireEvent, render, screen } from '@testing-library/react' +import Filter, { TIME_PERIOD_MAPPING } from './filter' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { count?: number }) => { + if (options?.count !== undefined) + return `${key} (${options.count})` + return key + }, + }), +})) + +vi.mock('@/service/use-log', () => ({ + useAnnotationsCount: () => ({ + data: { count: 10 }, + isLoading: false, + }), +})) + +describe('Filter', () => { + const defaultQueryParams: QueryParam = { + period: '9', + annotation_status: 'all', + keyword: '', + } + + const mockSetQueryParams = vi.fn() + const defaultProps = { + appId: 'test-app-id', + queryParams: defaultQueryParams, + setQueryParams: mockSetQueryParams, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render filter components', () => { + render() + + expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument() + }) + + it('should return null when loading', () => { + // This test verifies the component renders correctly with the mocked data + const { container } = render() + expect(container.firstChild).not.toBeNull() + }) + + it('should render sort component in chat mode', () => { + render() + + expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument() + }) + + it('should not render sort component when not in chat mode', () => { + render() + + expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument() + }) + }) + + describe('TIME_PERIOD_MAPPING', () => { + it('should have correct period keys', () => { + expect(Object.keys(TIME_PERIOD_MAPPING)).toEqual(['1', '2', '3', '4', '5', '6', '7', '8', '9']) + }) + + it('should have today period with value 0', () => { + expect(TIME_PERIOD_MAPPING['1'].value).toBe(0) + expect(TIME_PERIOD_MAPPING['1'].name).toBe('today') + }) + + it('should have last7days period with value 7', () => { + expect(TIME_PERIOD_MAPPING['2'].value).toBe(7) + expect(TIME_PERIOD_MAPPING['2'].name).toBe('last7days') + }) + + it('should have last4weeks period with value 28', () => { + expect(TIME_PERIOD_MAPPING['3'].value).toBe(28) + expect(TIME_PERIOD_MAPPING['3'].name).toBe('last4weeks') + }) + + it('should have allTime period with value -1', () => { + expect(TIME_PERIOD_MAPPING['9'].value).toBe(-1) + expect(TIME_PERIOD_MAPPING['9'].name).toBe('allTime') + }) + }) + + describe('User Interactions', () => { + it('should update keyword when typing in search input', () => { + render() + + const searchInput = screen.getByPlaceholderText('operation.search') + fireEvent.change(searchInput, { target: { value: 'test search' } }) + + expect(mockSetQueryParams).toHaveBeenCalledWith({ + ...defaultQueryParams, + keyword: 'test search', + }) + }) + + it('should clear keyword when clear button is clicked', () => { + const propsWithKeyword = { + ...defaultProps, + queryParams: { ...defaultQueryParams, keyword: 'existing search' }, + } + + render() + + const clearButton = screen.getByTestId('input-clear') + fireEvent.click(clearButton) + + expect(mockSetQueryParams).toHaveBeenCalledWith({ + ...defaultQueryParams, + keyword: '', + }) + }) + }) + + describe('Query Params', () => { + it('should display "today" when period is set to 1', () => { + const propsWithPeriod = { + ...defaultProps, + queryParams: { ...defaultQueryParams, period: '1' }, + } + + render() + + // Period '1' maps to 'today' in TIME_PERIOD_MAPPING + expect(screen.getByText('filter.period.today')).toBeInTheDocument() + }) + + it('should display "last7days" when period is set to 2', () => { + const propsWithPeriod = { + ...defaultProps, + queryParams: { ...defaultQueryParams, period: '2' }, + } + + render() + + expect(screen.getByText('filter.period.last7days')).toBeInTheDocument() + }) + + it('should display "allTime" when period is set to 9', () => { + render() + + // Default period is '9' which maps to 'allTime' + expect(screen.getByText('filter.period.allTime')).toBeInTheDocument() + }) + + it('should display annotated status with count when annotation_status is annotated', () => { + const propsWithAnnotation = { + ...defaultProps, + queryParams: { ...defaultQueryParams, annotation_status: 'annotated' }, + } + + render() + + // The mock returns count: 10, so the text should include the count + expect(screen.getByText('filter.annotation.annotated (10)')).toBeInTheDocument() + }) + + it('should display not_annotated status when annotation_status is not_annotated', () => { + const propsWithNotAnnotated = { + ...defaultProps, + queryParams: { ...defaultQueryParams, annotation_status: 'not_annotated' }, + } + + render() + + expect(screen.getByText('filter.annotation.not_annotated')).toBeInTheDocument() + }) + + it('should display all annotation status when annotation_status is all', () => { + render() + + // Default annotation_status is 'all' + expect(screen.getByText('filter.annotation.all')).toBeInTheDocument() + }) + }) + + describe('Chat Mode', () => { + it('should display sort component with sort_by parameter', () => { + const propsWithSort = { + ...defaultProps, + isChatMode: true, + queryParams: { ...defaultQueryParams, sort_by: 'created_at' }, + } + + render() + + expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument() + }) + + it('should handle descending sort order', () => { + const propsWithDescSort = { + ...defaultProps, + isChatMode: true, + queryParams: { ...defaultQueryParams, sort_by: '-created_at' }, + } + + render() + + expect(screen.getByPlaceholderText('operation.search')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/log/model-info.spec.tsx b/web/app/components/app/log/model-info.spec.tsx new file mode 100644 index 0000000000..c8263c2360 --- /dev/null +++ b/web/app/components/app/log/model-info.spec.tsx @@ -0,0 +1,221 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import ModelInfo from './model-info' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useTextGenerationCurrentProviderAndModelAndModelList: () => ({ + currentModel: { + model: 'gpt-4', + model_display_name: 'GPT-4', + }, + currentProvider: { + provider: 'openai', + label: 'OpenAI', + }, + }), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-icon', () => ({ + default: ({ modelName }: { provider: unknown, modelName: string }) => ( +
ModelIcon
+ ), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-name', () => ({ + default: ({ modelItem, showMode }: { modelItem: { model: string }, showMode: boolean }) => ( +
+ {modelItem?.model} +
+ ), +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => ( +
{children}
+ ), + PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick: () => void }) => ( +
{children}
+ ), + PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})) + +describe('ModelInfo', () => { + const defaultModel = { + name: 'gpt-4', + provider: 'openai', + completion_params: { + temperature: 0.7, + top_p: 0.9, + presence_penalty: 0.1, + max_tokens: 2048, + stop: ['END'], + }, + } + + describe('Rendering', () => { + it('should render model icon', () => { + render() + + expect(screen.getByTestId('model-icon')).toBeInTheDocument() + }) + + it('should render model name', () => { + render() + + expect(screen.getByTestId('model-name')).toBeInTheDocument() + expect(screen.getByTestId('model-name')).toHaveTextContent('gpt-4') + }) + + it('should render info icon button', () => { + const { container } = render() + + // The info button should contain an SVG icon + const svgs = container.querySelectorAll('svg') + expect(svgs.length).toBeGreaterThan(0) + }) + + it('should show model name with showMode prop', () => { + render() + + expect(screen.getByTestId('model-name')).toHaveAttribute('data-show-mode', 'true') + }) + }) + + describe('Info Panel Toggle', () => { + it('should be closed by default', () => { + render() + + expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false') + }) + + it('should open when info button is clicked', () => { + render() + + const trigger = screen.getByTestId('portal-trigger') + fireEvent.click(trigger) + + expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'true') + }) + + it('should close when info button is clicked again', () => { + render() + + const trigger = screen.getByTestId('portal-trigger') + + // Open + fireEvent.click(trigger) + expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'true') + + // Close + fireEvent.click(trigger) + expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false') + }) + }) + + describe('Model Parameters Display', () => { + it('should render model params header', () => { + render() + + expect(screen.getByText('detail.modelParams')).toBeInTheDocument() + }) + + it('should render temperature parameter', () => { + render() + + expect(screen.getByText('Temperature')).toBeInTheDocument() + expect(screen.getByText('0.7')).toBeInTheDocument() + }) + + it('should render top_p parameter', () => { + render() + + expect(screen.getByText('Top P')).toBeInTheDocument() + expect(screen.getByText('0.9')).toBeInTheDocument() + }) + + it('should render presence_penalty parameter', () => { + render() + + expect(screen.getByText('Presence Penalty')).toBeInTheDocument() + expect(screen.getByText('0.1')).toBeInTheDocument() + }) + + it('should render max_tokens parameter', () => { + render() + + expect(screen.getByText('Max Token')).toBeInTheDocument() + expect(screen.getByText('2048')).toBeInTheDocument() + }) + + it('should render stop parameter as comma-separated values', () => { + render() + + expect(screen.getByText('Stop')).toBeInTheDocument() + expect(screen.getByText('END')).toBeInTheDocument() + }) + }) + + describe('Missing Parameters', () => { + it('should show dash for missing parameters', () => { + const modelWithNoParams = { + name: 'gpt-4', + provider: 'openai', + completion_params: {}, + } + + render() + + const dashes = screen.getAllByText('-') + expect(dashes.length).toBeGreaterThan(0) + }) + + it('should show dash for non-array stop values', () => { + const modelWithInvalidStop = { + name: 'gpt-4', + provider: 'openai', + completion_params: { + stop: 'not-an-array', + }, + } + + render() + + const stopValues = screen.getAllByText('-') + expect(stopValues.length).toBeGreaterThan(0) + }) + + it('should join array stop values with comma', () => { + const modelWithMultipleStops = { + name: 'gpt-4', + provider: 'openai', + completion_params: { + stop: ['END', 'STOP', 'DONE'], + }, + } + + render() + + expect(screen.getByText('END,STOP,DONE')).toBeInTheDocument() + }) + }) + + describe('Model without completion_params', () => { + it('should handle undefined completion_params', () => { + const modelWithNoCompletionParams = { + name: 'gpt-4', + provider: 'openai', + } + + render() + + expect(screen.getByTestId('model-icon')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app/log/var-panel.spec.tsx b/web/app/components/app/log/var-panel.spec.tsx new file mode 100644 index 0000000000..eff186e5b9 --- /dev/null +++ b/web/app/components/app/log/var-panel.spec.tsx @@ -0,0 +1,217 @@ +import { act, fireEvent, render, screen } from '@testing-library/react' +import VarPanel from './var-panel' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/app/components/base/image-uploader/image-preview', () => ({ + default: ({ url, title, onCancel }: { url: string, title: string, onCancel: () => void }) => ( +
+ +
+ ), +})) + +describe('VarPanel', () => { + const defaultProps = { + varList: [ + { label: 'name', value: 'John Doe' }, + { label: 'age', value: '25' }, + ], + message_files: [], + } + + describe('Rendering', () => { + it('should render variables section header', () => { + render() + + expect(screen.getByText('detail.variables')).toBeInTheDocument() + }) + + it('should render variable labels with braces', () => { + render() + + expect(screen.getByText('name')).toBeInTheDocument() + expect(screen.getByText('age')).toBeInTheDocument() + }) + + it('should render variable values', () => { + render() + + expect(screen.getByText('John Doe')).toBeInTheDocument() + expect(screen.getByText('25')).toBeInTheDocument() + }) + + it('should render opening and closing braces', () => { + render() + + const openingBraces = screen.getAllByText('{{') + const closingBraces = screen.getAllByText('}}') + + expect(openingBraces.length).toBe(2) + expect(closingBraces.length).toBe(2) + }) + + it('should render Variable02 icon', () => { + const { container } = render() + + const svg = container.querySelector('svg') + expect(svg).toBeInTheDocument() + }) + }) + + describe('Collapse/Expand', () => { + it('should show expanded state by default', () => { + render() + + expect(screen.getByText('John Doe')).toBeInTheDocument() + expect(screen.getByText('25')).toBeInTheDocument() + }) + + it('should collapse when header is clicked', () => { + render() + + const header = screen.getByText('detail.variables').closest('div') + fireEvent.click(header!) + + expect(screen.queryByText('John Doe')).not.toBeInTheDocument() + expect(screen.queryByText('25')).not.toBeInTheDocument() + }) + + it('should expand when clicked again', () => { + render() + + const header = screen.getByText('detail.variables').closest('div') + + // Collapse + fireEvent.click(header!) + expect(screen.queryByText('John Doe')).not.toBeInTheDocument() + + // Expand + fireEvent.click(header!) + expect(screen.getByText('John Doe')).toBeInTheDocument() + }) + + it('should show arrow icon when collapsed', () => { + const { container } = render() + + const header = screen.getByText('detail.variables').closest('div') + fireEvent.click(header!) + + // When collapsed, there should be SVG icons in the component + const svgs = container.querySelectorAll('svg') + expect(svgs.length).toBeGreaterThan(0) + }) + + it('should show arrow icon when expanded', () => { + const { container } = render() + + // When expanded, there should be SVG icons in the component + const svgs = container.querySelectorAll('svg') + expect(svgs.length).toBeGreaterThan(0) + }) + }) + + describe('Message Files', () => { + it('should not render images section when message_files is empty', () => { + render() + + expect(screen.queryByText('detail.uploadImages')).not.toBeInTheDocument() + }) + + it('should render images section when message_files has items', () => { + const propsWithFiles = { + ...defaultProps, + message_files: ['https://example.com/image1.jpg', 'https://example.com/image2.jpg'], + } + + render() + + expect(screen.getByText('detail.uploadImages')).toBeInTheDocument() + }) + + it('should render image thumbnails with correct background', () => { + const propsWithFiles = { + ...defaultProps, + message_files: ['https://example.com/image1.jpg'], + } + + const { container } = render() + + const thumbnail = container.querySelector('[style*="background-image"]') + expect(thumbnail).toBeInTheDocument() + expect(thumbnail).toHaveStyle({ backgroundImage: 'url(https://example.com/image1.jpg)' }) + }) + + it('should open image preview when thumbnail is clicked', () => { + const propsWithFiles = { + ...defaultProps, + message_files: ['https://example.com/image1.jpg'], + } + + const { container } = render() + + const thumbnail = container.querySelector('[style*="background-image"]') + fireEvent.click(thumbnail!) + + expect(screen.getByTestId('image-preview')).toBeInTheDocument() + expect(screen.getByTestId('image-preview')).toHaveAttribute('data-url', 'https://example.com/image1.jpg') + }) + + it('should close image preview when close button is clicked', () => { + const propsWithFiles = { + ...defaultProps, + message_files: ['https://example.com/image1.jpg'], + } + + const { container } = render() + + // Open preview + const thumbnail = container.querySelector('[style*="background-image"]') + fireEvent.click(thumbnail!) + + expect(screen.getByTestId('image-preview')).toBeInTheDocument() + + // Close preview + act(() => { + fireEvent.click(screen.getByTestId('close-preview')) + }) + + expect(screen.queryByTestId('image-preview')).not.toBeInTheDocument() + }) + }) + + describe('Empty State', () => { + it('should render with empty varList', () => { + const emptyProps = { + varList: [], + message_files: [], + } + + render() + + expect(screen.getByText('detail.variables')).toBeInTheDocument() + }) + }) + + describe('Multiple Images', () => { + it('should render multiple image thumbnails', () => { + const propsWithMultipleFiles = { + ...defaultProps, + message_files: [ + 'https://example.com/image1.jpg', + 'https://example.com/image2.jpg', + 'https://example.com/image3.jpg', + ], + } + + const { container } = render() + + const thumbnails = container.querySelectorAll('[style*="background-image"]') + expect(thumbnails.length).toBe(3) + }) + }) +}) diff --git a/web/app/components/app/overview/trigger-card.spec.tsx b/web/app/components/app/overview/trigger-card.spec.tsx new file mode 100644 index 0000000000..0ee9da582d --- /dev/null +++ b/web/app/components/app/overview/trigger-card.spec.tsx @@ -0,0 +1,390 @@ +import type { AppDetailResponse } from '@/models/app' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { AppModeEnum } from '@/types/app' +import TriggerCard from './trigger-card' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { count?: number }) => { + if (options?.count !== undefined) + return `${key} (${options.count})` + return key + }, + }), +})) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceEditor: true, + }), +})) + +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.example.com${path}`, +})) + +const mockSetTriggerStatus = vi.fn() +const mockSetTriggerStatuses = vi.fn() +vi.mock('@/app/components/workflow/store/trigger-status', () => ({ + useTriggerStatusStore: () => ({ + setTriggerStatus: mockSetTriggerStatus, + setTriggerStatuses: mockSetTriggerStatuses, + }), +})) + +const mockUpdateTriggerStatus = vi.fn() +const mockInvalidateAppTriggers = vi.fn() +let mockTriggers: Array<{ + id: string + node_id: string + title: string + trigger_type: string + status: string + provider_name?: string +}> = [] +let mockIsLoading = false + +vi.mock('@/service/use-tools', () => ({ + useAppTriggers: () => ({ + data: { data: mockTriggers }, + isLoading: mockIsLoading, + }), + useUpdateTriggerStatus: () => ({ + mutateAsync: mockUpdateTriggerStatus, + }), + useInvalidateAppTriggers: () => mockInvalidateAppTriggers, +})) + +vi.mock('@/service/use-triggers', () => ({ + useAllTriggerPlugins: () => ({ + data: [ + { id: 'plugin-1', name: 'Test Plugin', icon: 'test-icon' }, + ], + }), +})) + +vi.mock('@/utils', () => ({ + canFindTool: () => false, +})) + +vi.mock('@/app/components/workflow/block-icon', () => ({ + default: ({ type }: { type: string }) => ( +
BlockIcon
+ ), +})) + +vi.mock('@/app/components/base/switch', () => ({ + default: ({ defaultValue, onChange, disabled }: { defaultValue: boolean, onChange: (v: boolean) => void, disabled: boolean }) => ( + + ), +})) + +describe('TriggerCard', () => { + const mockAppInfo = { + id: 'test-app-id', + name: 'Test App', + description: 'Test description', + mode: AppModeEnum.WORKFLOW, + icon_type: 'emoji', + icon: 'test-icon', + icon_background: '#ffffff', + created_at: Date.now(), + updated_at: Date.now(), + enable_site: true, + enable_api: true, + } as AppDetailResponse + + const mockOnToggleResult = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockTriggers = [] + mockIsLoading = false + mockUpdateTriggerStatus.mockResolvedValue({}) + }) + + describe('Loading State', () => { + it('should render loading skeleton when isLoading is true', () => { + mockIsLoading = true + + const { container } = render( + , + ) + + expect(container.querySelector('.animate-pulse')).toBeInTheDocument() + }) + }) + + describe('Empty State', () => { + it('should show no triggers added message when triggers is empty', () => { + mockTriggers = [] + + render() + + expect(screen.getByText('overview.triggerInfo.noTriggerAdded')).toBeInTheDocument() + }) + + it('should show trigger status description when no triggers', () => { + mockTriggers = [] + + render() + + expect(screen.getByText('overview.triggerInfo.triggerStatusDescription')).toBeInTheDocument() + }) + + it('should show learn more link when no triggers', () => { + mockTriggers = [] + + render() + + const learnMoreLink = screen.getByText('overview.triggerInfo.learnAboutTriggers') + expect(learnMoreLink).toBeInTheDocument() + expect(learnMoreLink).toHaveAttribute('href', 'https://docs.example.com/use-dify/nodes/trigger/overview') + }) + }) + + describe('With Triggers', () => { + beforeEach(() => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Webhook Trigger', + trigger_type: 'trigger-webhook', + status: 'enabled', + }, + { + id: 'trigger-2', + node_id: 'node-2', + title: 'Schedule Trigger', + trigger_type: 'trigger-schedule', + status: 'disabled', + }, + ] + }) + + it('should show triggers count message', () => { + render() + + expect(screen.getByText('overview.triggerInfo.triggersAdded (2)')).toBeInTheDocument() + }) + + it('should render trigger titles', () => { + render() + + expect(screen.getByText('Webhook Trigger')).toBeInTheDocument() + expect(screen.getByText('Schedule Trigger')).toBeInTheDocument() + }) + + it('should show running status for enabled triggers', () => { + render() + + expect(screen.getByText('overview.status.running')).toBeInTheDocument() + }) + + it('should show disable status for disabled triggers', () => { + render() + + expect(screen.getByText('overview.status.disable')).toBeInTheDocument() + }) + + it('should render block icons for each trigger', () => { + render() + + const blockIcons = screen.getAllByTestId('block-icon') + expect(blockIcons.length).toBe(2) + }) + + it('should render switches for each trigger', () => { + render() + + const switches = screen.getAllByTestId('switch') + expect(switches.length).toBe(2) + }) + }) + + describe('Toggle Trigger', () => { + beforeEach(() => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Test Trigger', + trigger_type: 'trigger-webhook', + status: 'disabled', + }, + ] + }) + + it('should call updateTriggerStatus when toggle is clicked', async () => { + render() + + const switchBtn = screen.getByTestId('switch') + fireEvent.click(switchBtn) + + await waitFor(() => { + expect(mockUpdateTriggerStatus).toHaveBeenCalledWith({ + appId: 'test-app-id', + triggerId: 'trigger-1', + enableTrigger: true, + }) + }) + }) + + it('should update trigger status in store optimistically', async () => { + render() + + const switchBtn = screen.getByTestId('switch') + fireEvent.click(switchBtn) + + await waitFor(() => { + expect(mockSetTriggerStatus).toHaveBeenCalledWith('node-1', 'enabled') + }) + }) + + it('should invalidate app triggers after successful update', async () => { + render() + + const switchBtn = screen.getByTestId('switch') + fireEvent.click(switchBtn) + + await waitFor(() => { + expect(mockInvalidateAppTriggers).toHaveBeenCalledWith('test-app-id') + }) + }) + + it('should call onToggleResult with null on success', async () => { + render() + + const switchBtn = screen.getByTestId('switch') + fireEvent.click(switchBtn) + + await waitFor(() => { + expect(mockOnToggleResult).toHaveBeenCalledWith(null) + }) + }) + + it('should rollback status and call onToggleResult with error on failure', async () => { + const error = new Error('Update failed') + mockUpdateTriggerStatus.mockRejectedValueOnce(error) + + render() + + const switchBtn = screen.getByTestId('switch') + fireEvent.click(switchBtn) + + await waitFor(() => { + expect(mockSetTriggerStatus).toHaveBeenCalledWith('node-1', 'disabled') + expect(mockOnToggleResult).toHaveBeenCalledWith(error) + }) + }) + }) + + describe('Trigger Types', () => { + it('should render webhook trigger type correctly', () => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Webhook', + trigger_type: 'trigger-webhook', + status: 'enabled', + }, + ] + + render() + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-type', 'trigger-webhook') + }) + + it('should render schedule trigger type correctly', () => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Schedule', + trigger_type: 'trigger-schedule', + status: 'enabled', + }, + ] + + render() + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-type', 'trigger-schedule') + }) + + it('should render plugin trigger type correctly', () => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Plugin', + trigger_type: 'trigger-plugin', + status: 'enabled', + provider_name: 'plugin-1', + }, + ] + + render() + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-type', 'trigger-plugin') + }) + }) + + describe('Editor Permissions', () => { + it('should render switches for triggers', () => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Test Trigger', + trigger_type: 'trigger-webhook', + status: 'enabled', + }, + ] + + render() + + const switchBtn = screen.getByTestId('switch') + expect(switchBtn).toBeInTheDocument() + }) + }) + + describe('Status Sync', () => { + it('should sync trigger statuses to store when data loads', () => { + mockTriggers = [ + { + id: 'trigger-1', + node_id: 'node-1', + title: 'Test', + trigger_type: 'trigger-webhook', + status: 'enabled', + }, + { + id: 'trigger-2', + node_id: 'node-2', + title: 'Test 2', + trigger_type: 'trigger-schedule', + status: 'disabled', + }, + ] + + render() + + expect(mockSetTriggerStatuses).toHaveBeenCalledWith({ + 'node-1': 'enabled', + 'node-2': 'disabled', + }) + }) + }) +}) diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index f1eadb9d05..730a39b68d 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -33,6 +33,7 @@ import { fetchWorkflowDraft } from '@/service/workflow' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { downloadBlob } from '@/utils/download' import { formatTime } from '@/utils/time' import { basePath } from '@/utils/var' @@ -161,13 +162,8 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { appID: app.id, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${app.name}.yml` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${app.name}.yml` }) } catch { notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) @@ -346,7 +342,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { dateFormat: `${t('segment.dateTimeFormat', { ns: 'datasetDocuments' })}`, }) return `${t('segment.editedAt', { ns: 'datasetDocuments' })} ${timeText}` - }, [app.updated_at, app.created_at]) + }, [app.updated_at, app.created_at, t]) return ( <> diff --git a/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx b/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx index 6ef5bcb308..f8015aa7c7 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx @@ -15,11 +15,11 @@ import ImagePreview from '@/app/components/base/image-uploader/image-preview' import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' import { SupportUploadFileTypes } from '@/app/components/workflow/types' import { cn } from '@/utils/classnames' +import { downloadUrl } from '@/utils/download' import { formatFileSize } from '@/utils/format' import FileImageRender from '../file-image-render' import FileTypeIcon from '../file-type-icon' import { - downloadFile, fileIsUploaded, getFileAppearanceType, getFileExtension, @@ -140,7 +140,7 @@ const FileInAttachmentItem = ({ showDownloadAction && ( { e.stopPropagation() - downloadFile(url || base64Url || '', name) + downloadUrl({ url: url || base64Url || '', fileName: name, target: '_blank' }) }} > diff --git a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx index 77dc3e35b8..d9118aac4f 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx @@ -8,9 +8,9 @@ import Button from '@/app/components/base/button' import { ReplayLine } from '@/app/components/base/icons/src/vender/other' import ImagePreview from '@/app/components/base/image-uploader/image-preview' import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' +import { downloadUrl } from '@/utils/download' import FileImageRender from '../file-image-render' import { - downloadFile, fileIsUploaded, } from '../utils' @@ -85,7 +85,7 @@ const FileImageItem = ({ className="absolute bottom-0.5 right-0.5 flex h-6 w-6 items-center justify-center rounded-lg bg-components-actionbar-bg shadow-md" onClick={(e) => { e.stopPropagation() - downloadFile(download_url || '', name) + downloadUrl({ url: download_url || '', fileName: name, target: '_blank' }) }} > diff --git a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx index 828864239a..af32f917b9 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx @@ -12,10 +12,10 @@ import VideoPreview from '@/app/components/base/file-uploader/video-preview' import { ReplayLine } from '@/app/components/base/icons/src/vender/other' import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' import { cn } from '@/utils/classnames' +import { downloadUrl } from '@/utils/download' import { formatFileSize } from '@/utils/format' import FileTypeIcon from '../file-type-icon' import { - downloadFile, fileIsUploaded, getFileAppearanceType, getFileExtension, @@ -100,7 +100,7 @@ const FileItem = ({ className="absolute -right-1 -top-1 hidden group-hover/file-item:flex" onClick={(e) => { e.stopPropagation() - downloadFile(download_url || '', name) + downloadUrl({ url: download_url || '', fileName: name, target: '_blank' }) }} > diff --git a/web/app/components/base/file-uploader/utils.spec.ts b/web/app/components/base/file-uploader/utils.spec.ts index de167a8c25..f69b3c27f5 100644 --- a/web/app/components/base/file-uploader/utils.spec.ts +++ b/web/app/components/base/file-uploader/utils.spec.ts @@ -1,4 +1,3 @@ -import type { MockInstance } from 'vitest' import mime from 'mime' import { SupportUploadFileTypes } from '@/app/components/workflow/types' import { upload } from '@/service/base' @@ -6,7 +5,6 @@ import { TransferMethod } from '@/types/app' import { FILE_EXTS } from '../prompt-editor/constants' import { FileAppearanceTypeEnum } from './types' import { - downloadFile, fileIsUploaded, fileUpload, getFileAppearanceType, @@ -782,74 +780,4 @@ describe('file-uploader utils', () => { } as any)).toBe(true) }) }) - - describe('downloadFile', () => { - let mockAnchor: HTMLAnchorElement - let createElementMock: MockInstance - let appendChildMock: MockInstance - let removeChildMock: MockInstance - - beforeEach(() => { - // Mock createElement and appendChild - mockAnchor = { - href: '', - download: '', - style: { display: '' }, - target: '', - title: '', - click: vi.fn(), - } as unknown as HTMLAnchorElement - - createElementMock = vi.spyOn(document, 'createElement').mockReturnValue(mockAnchor as any) - appendChildMock = vi.spyOn(document.body, 'appendChild').mockImplementation((node: Node) => { - return node - }) - removeChildMock = vi.spyOn(document.body, 'removeChild').mockImplementation((node: Node) => { - return node - }) - }) - - afterEach(() => { - vi.resetAllMocks() - }) - - it('should create and trigger download with correct attributes', () => { - const url = 'https://example.com/test.pdf' - const filename = 'test.pdf' - - downloadFile(url, filename) - - // Verify anchor element was created with correct properties - expect(createElementMock).toHaveBeenCalledWith('a') - expect(mockAnchor.href).toBe(url) - expect(mockAnchor.download).toBe(filename) - expect(mockAnchor.style.display).toBe('none') - expect(mockAnchor.target).toBe('_blank') - expect(mockAnchor.title).toBe(filename) - - // Verify DOM operations - expect(appendChildMock).toHaveBeenCalledWith(mockAnchor) - expect(mockAnchor.click).toHaveBeenCalled() - expect(removeChildMock).toHaveBeenCalledWith(mockAnchor) - }) - - it('should handle empty filename', () => { - const url = 'https://example.com/test.pdf' - const filename = '' - - downloadFile(url, filename) - - expect(mockAnchor.download).toBe('') - expect(mockAnchor.title).toBe('') - }) - - it('should handle empty url', () => { - const url = '' - const filename = 'test.pdf' - - downloadFile(url, filename) - - expect(mockAnchor.href).toBe('') - }) - }) }) diff --git a/web/app/components/base/file-uploader/utils.ts b/web/app/components/base/file-uploader/utils.ts index 5d5754b8fe..23e460db51 100644 --- a/web/app/components/base/file-uploader/utils.ts +++ b/web/app/components/base/file-uploader/utils.ts @@ -249,15 +249,3 @@ export const fileIsUploaded = (file: FileEntity) => { if (file.transferMethod === TransferMethod.remote_url && file.progress === 100) return true } - -export const downloadFile = (url: string, filename: string) => { - const anchor = document.createElement('a') - anchor.href = url - anchor.download = filename - anchor.style.display = 'none' - anchor.target = '_blank' - anchor.title = filename - document.body.appendChild(anchor) - anchor.click() - document.body.removeChild(anchor) -} diff --git a/web/app/components/base/image-uploader/image-preview.tsx b/web/app/components/base/image-uploader/image-preview.tsx index b6a07c60aa..0641af3d79 100644 --- a/web/app/components/base/image-uploader/image-preview.tsx +++ b/web/app/components/base/image-uploader/image-preview.tsx @@ -8,6 +8,7 @@ import { createPortal } from 'react-dom' import { useHotkeys } from 'react-hotkeys-hook' import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' +import { downloadUrl } from '@/utils/download' type ImagePreviewProps = { url: string @@ -60,27 +61,14 @@ const ImagePreview: FC = ({ const downloadImage = () => { // Open in a new window, considering the case when the page is inside an iframe - if (url.startsWith('http') || url.startsWith('https')) { - const a = document.createElement('a') - a.href = url - a.target = '_blank' - a.download = title - a.click() - } - else if (url.startsWith('data:image')) { - // Base64 image - const a = document.createElement('a') - a.href = url - a.target = '_blank' - a.download = title - a.click() - } - else { - Toast.notify({ - type: 'error', - message: `Unable to open image: ${url}`, - }) + if (url.startsWith('http') || url.startsWith('https') || url.startsWith('data:image')) { + downloadUrl({ url, fileName: title, target: '_blank' }) + return } + Toast.notify({ + type: 'error', + message: `Unable to open image: ${url}`, + }) } const zoomIn = () => { @@ -135,12 +123,7 @@ const ImagePreview: FC = ({ catch (err) { console.error('Failed to copy image:', err) - const link = document.createElement('a') - link.href = url - link.download = `${title}.png` - document.body.appendChild(link) - link.click() - document.body.removeChild(link) + downloadUrl({ url, fileName: `${title}.png` }) Toast.notify({ type: 'info', @@ -215,6 +198,7 @@ const ImagePreview: FC = ({ tabIndex={-1} > { } + {/* eslint-disable-next-line next/no-img-element */} {title} { }, [isShow]) const downloadQR = () => { - const canvas = document.getElementsByTagName('canvas')[0] - const link = document.createElement('a') - link.download = 'qrcode.png' - link.href = canvas.toDataURL() - link.click() + const canvas = qrCodeRef.current?.querySelector('canvas') + if (!(canvas instanceof HTMLCanvasElement)) + return + downloadUrl({ url: canvas.toDataURL(), fileName: 'qrcode.png' }) } const handlePanelClick = (event: React.MouseEvent) => { diff --git a/web/app/components/billing/annotation-full/usage.spec.tsx b/web/app/components/billing/annotation-full/usage.spec.tsx new file mode 100644 index 0000000000..c5fd1a2b10 --- /dev/null +++ b/web/app/components/billing/annotation-full/usage.spec.tsx @@ -0,0 +1,57 @@ +import { render, screen } from '@testing-library/react' +import Usage from './usage' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +const mockPlan = { + usage: { + annotatedResponse: 50, + }, + total: { + annotatedResponse: 100, + }, +} + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: mockPlan, + }), +})) + +describe('Usage', () => { + // Rendering: renders UsageInfo with correct props from context + describe('Rendering', () => { + it('should render usage info with data from provider context', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('annotatedResponse.quotaTitle')).toBeInTheDocument() + }) + + it('should pass className to UsageInfo component', () => { + // Arrange + const testClassName = 'mt-4' + + // Act + const { container } = render() + + // Assert + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass(testClassName) + }) + + it('should display usage and total values from context', () => { + // Arrange & Act + render() + + // Assert + expect(screen.getByText('50')).toBeInTheDocument() + expect(screen.getByText('100')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/billing/billing-page/index.spec.tsx b/web/app/components/billing/billing-page/index.spec.tsx index 8b68f74012..f80c688d47 100644 --- a/web/app/components/billing/billing-page/index.spec.tsx +++ b/web/app/components/billing/billing-page/index.spec.tsx @@ -73,6 +73,56 @@ describe('Billing', () => { }) }) + it('returns the refetched url from the async callback', async () => { + const newUrl = 'https://new-billing-url' + refetchMock.mockResolvedValue({ data: newUrl }) + render() + + const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ }) + fireEvent.click(actionButton) + + await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled()) + const [asyncCallback] = openAsyncWindowMock.mock.calls[0] + + // Execute the async callback passed to openAsyncWindow + const result = await asyncCallback() + expect(result).toBe(newUrl) + expect(refetchMock).toHaveBeenCalled() + }) + + it('returns null when refetch returns no url', async () => { + refetchMock.mockResolvedValue({ data: null }) + render() + + const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ }) + fireEvent.click(actionButton) + + await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled()) + const [asyncCallback] = openAsyncWindowMock.mock.calls[0] + + // Execute the async callback when url is null + const result = await asyncCallback() + expect(result).toBeNull() + }) + + it('handles errors in onError callback', async () => { + const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {}) + render() + + const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ }) + fireEvent.click(actionButton) + + await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled()) + const [, options] = openAsyncWindowMock.mock.calls[0] + + // Execute the onError callback + const testError = new Error('Test error') + options.onError(testError) + expect(consoleError).toHaveBeenCalledWith('Failed to fetch billing url', testError) + + consoleError.mockRestore() + }) + it('disables the button while billing url is fetching', () => { fetching = true render() diff --git a/web/app/components/billing/plan/index.spec.tsx b/web/app/components/billing/plan/index.spec.tsx index 473f81f9f4..fb1800653e 100644 --- a/web/app/components/billing/plan/index.spec.tsx +++ b/web/app/components/billing/plan/index.spec.tsx @@ -125,4 +125,70 @@ describe('PlanComp', () => { expect(setShowAccountSettingModalMock).toHaveBeenCalledWith(null) }) + + it('does not trigger verify when isPending is true', async () => { + isPending = true + render() + + const verifyBtn = screen.getByText('education.toVerified') + fireEvent.click(verifyBtn) + + await waitFor(() => expect(mutateAsyncMock).not.toHaveBeenCalled()) + }) + + it('renders sandbox plan', () => { + providerContextMock.mockReturnValue({ + plan: { ...planMock, type: Plan.sandbox }, + enableEducationPlan: false, + allowRefreshEducationVerify: false, + isEducationAccount: false, + }) + render() + + expect(screen.getByText('billing.plans.sandbox.name')).toBeInTheDocument() + }) + + it('renders team plan', () => { + providerContextMock.mockReturnValue({ + plan: { ...planMock, type: Plan.team }, + enableEducationPlan: false, + allowRefreshEducationVerify: false, + isEducationAccount: false, + }) + render() + + expect(screen.getByText('billing.plans.team.name')).toBeInTheDocument() + }) + + it('shows verify button when education account is about to expire', () => { + providerContextMock.mockReturnValue({ + plan: planMock, + enableEducationPlan: true, + allowRefreshEducationVerify: true, + isEducationAccount: true, + }) + render() + + expect(screen.getByText('education.toVerified')).toBeInTheDocument() + }) + + it('handles modal onConfirm and onCancel callbacks', async () => { + mutateAsyncMock.mockRejectedValueOnce(new Error('boom')) + render() + + // Trigger verify to show modal + const verifyBtn = screen.getByText('education.toVerified') + fireEvent.click(verifyBtn) + + await waitFor(() => expect(screen.getByTestId('verify-modal').getAttribute('data-is-show')).toBe('true')) + + // Get the props passed to the modal and call onConfirm/onCancel + const lastCall = verifyStateModalMock.mock.calls[verifyStateModalMock.mock.calls.length - 1][0] + expect(lastCall.onConfirm).toBeDefined() + expect(lastCall.onCancel).toBeDefined() + + // Call onConfirm to close modal + lastCall.onConfirm() + lastCall.onCancel() + }) }) diff --git a/web/app/components/billing/pricing/assets/index.spec.tsx b/web/app/components/billing/pricing/assets/index.spec.tsx index 7980f9a182..cc56c57593 100644 --- a/web/app/components/billing/pricing/assets/index.spec.tsx +++ b/web/app/components/billing/pricing/assets/index.spec.tsx @@ -52,6 +52,24 @@ describe('Pricing Assets', () => { expect(rects.some(rect => rect.getAttribute('fill') === 'var(--color-saas-dify-blue-accessible)')).toBe(true) }) + it('should render inactive state for Cloud', () => { + // Arrange + const { container } = render() + + // Assert + const rects = Array.from(container.querySelectorAll('rect')) + expect(rects.some(rect => rect.getAttribute('fill') === 'var(--color-text-primary)')).toBe(true) + }) + + it('should render active state for SelfHosted', () => { + // Arrange + const { container } = render() + + // Assert + const rects = Array.from(container.querySelectorAll('rect')) + expect(rects.some(rect => rect.getAttribute('fill') === 'var(--color-saas-dify-blue-accessible)')).toBe(true) + }) + it('should render inactive state for SelfHosted', () => { // Arrange const { container } = render() diff --git a/web/app/components/billing/utils/index.spec.ts b/web/app/components/billing/utils/index.spec.ts new file mode 100644 index 0000000000..03a159c18a --- /dev/null +++ b/web/app/components/billing/utils/index.spec.ts @@ -0,0 +1,301 @@ +import type { CurrentPlanInfoBackend } from '../type' +import { DocumentProcessingPriority, Plan } from '../type' +import { getPlanVectorSpaceLimitMB, parseCurrentPlan, parseVectorSpaceToMB } from './index' + +describe('billing utils', () => { + // parseVectorSpaceToMB tests + describe('parseVectorSpaceToMB', () => { + it('should parse MB values correctly', () => { + expect(parseVectorSpaceToMB('50MB')).toBe(50) + expect(parseVectorSpaceToMB('100MB')).toBe(100) + }) + + it('should parse GB values and convert to MB', () => { + expect(parseVectorSpaceToMB('5GB')).toBe(5 * 1024) + expect(parseVectorSpaceToMB('20GB')).toBe(20 * 1024) + }) + + it('should be case insensitive', () => { + expect(parseVectorSpaceToMB('50mb')).toBe(50) + expect(parseVectorSpaceToMB('5gb')).toBe(5 * 1024) + }) + + it('should return 0 for invalid format', () => { + expect(parseVectorSpaceToMB('50')).toBe(0) + expect(parseVectorSpaceToMB('invalid')).toBe(0) + expect(parseVectorSpaceToMB('')).toBe(0) + expect(parseVectorSpaceToMB('50TB')).toBe(0) + }) + }) + + // getPlanVectorSpaceLimitMB tests + describe('getPlanVectorSpaceLimitMB', () => { + it('should return correct vector space for sandbox plan', () => { + expect(getPlanVectorSpaceLimitMB(Plan.sandbox)).toBe(50) + }) + + it('should return correct vector space for professional plan', () => { + expect(getPlanVectorSpaceLimitMB(Plan.professional)).toBe(5 * 1024) + }) + + it('should return correct vector space for team plan', () => { + expect(getPlanVectorSpaceLimitMB(Plan.team)).toBe(20 * 1024) + }) + + it('should return 0 for invalid plan', () => { + // @ts-expect-error - Testing invalid plan input + expect(getPlanVectorSpaceLimitMB('invalid')).toBe(0) + }) + }) + + // parseCurrentPlan tests + describe('parseCurrentPlan', () => { + const createMockPlanData = (overrides: Partial = {}): CurrentPlanInfoBackend => ({ + billing: { + enabled: true, + subscription: { + plan: Plan.sandbox, + }, + }, + members: { + size: 1, + limit: 1, + }, + apps: { + size: 2, + limit: 5, + }, + vector_space: { + size: 10, + limit: 50, + }, + annotation_quota_limit: { + size: 5, + limit: 10, + }, + documents_upload_quota: { + size: 20, + limit: 0, + }, + docs_processing: DocumentProcessingPriority.standard, + can_replace_logo: false, + model_load_balancing_enabled: false, + dataset_operator_enabled: false, + education: { + enabled: false, + activated: false, + }, + webapp_copyright_enabled: false, + workspace_members: { + size: 1, + limit: 1, + }, + is_allow_transfer_workspace: false, + knowledge_pipeline: { + publish_enabled: false, + }, + ...overrides, + }) + + it('should parse plan type correctly', () => { + const data = createMockPlanData() + const result = parseCurrentPlan(data) + expect(result.type).toBe(Plan.sandbox) + }) + + it('should parse usage values correctly', () => { + const data = createMockPlanData() + const result = parseCurrentPlan(data) + + expect(result.usage.vectorSpace).toBe(10) + expect(result.usage.buildApps).toBe(2) + expect(result.usage.teamMembers).toBe(1) + expect(result.usage.annotatedResponse).toBe(5) + expect(result.usage.documentsUploadQuota).toBe(20) + }) + + it('should parse total limits correctly', () => { + const data = createMockPlanData() + const result = parseCurrentPlan(data) + + expect(result.total.vectorSpace).toBe(50) + expect(result.total.buildApps).toBe(5) + expect(result.total.teamMembers).toBe(1) + expect(result.total.annotatedResponse).toBe(10) + }) + + it('should convert 0 limits to NUM_INFINITE (-1)', () => { + const data = createMockPlanData({ + documents_upload_quota: { + size: 20, + limit: 0, + }, + }) + const result = parseCurrentPlan(data) + expect(result.total.documentsUploadQuota).toBe(-1) + }) + + it('should handle api_rate_limit quota', () => { + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: null, + }, + }) + const result = parseCurrentPlan(data) + + expect(result.usage.apiRateLimit).toBe(100) + expect(result.total.apiRateLimit).toBe(5000) + }) + + it('should handle trigger_event quota', () => { + const data = createMockPlanData({ + trigger_event: { + usage: 50, + limit: 3000, + reset_date: null, + }, + }) + const result = parseCurrentPlan(data) + + expect(result.usage.triggerEvents).toBe(50) + expect(result.total.triggerEvents).toBe(3000) + }) + + it('should use fallback for api_rate_limit when not provided', () => { + const data = createMockPlanData() + const result = parseCurrentPlan(data) + + // Fallback to plan preset value for sandbox: 5000 + expect(result.total.apiRateLimit).toBe(5000) + }) + + it('should convert 0 or -1 rate limits to NUM_INFINITE', () => { + const data = createMockPlanData({ + api_rate_limit: { + usage: 0, + limit: 0, + reset_date: null, + }, + }) + const result = parseCurrentPlan(data) + expect(result.total.apiRateLimit).toBe(-1) + + const data2 = createMockPlanData({ + api_rate_limit: { + usage: 0, + limit: -1, + reset_date: null, + }, + }) + const result2 = parseCurrentPlan(data2) + expect(result2.total.apiRateLimit).toBe(-1) + }) + + it('should handle reset dates with milliseconds timestamp', () => { + const futureDate = Date.now() + 86400000 // Tomorrow in ms + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: futureDate, + }, + }) + const result = parseCurrentPlan(data) + + expect(result.reset.apiRateLimit).toBe(1) + }) + + it('should handle reset dates with seconds timestamp', () => { + const futureDate = Math.floor(Date.now() / 1000) + 86400 // Tomorrow in seconds + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: futureDate, + }, + }) + const result = parseCurrentPlan(data) + + expect(result.reset.apiRateLimit).toBe(1) + }) + + it('should handle reset dates in YYYYMMDD format', () => { + const tomorrow = new Date() + tomorrow.setDate(tomorrow.getDate() + 1) + const year = tomorrow.getFullYear() + const month = String(tomorrow.getMonth() + 1).padStart(2, '0') + const day = String(tomorrow.getDate()).padStart(2, '0') + const dateNumber = Number.parseInt(`${year}${month}${day}`, 10) + + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: dateNumber, + }, + }) + const result = parseCurrentPlan(data) + + expect(result.reset.apiRateLimit).toBe(1) + }) + + it('should return null for invalid reset dates', () => { + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: 0, + }, + }) + const result = parseCurrentPlan(data) + expect(result.reset.apiRateLimit).toBeNull() + }) + + it('should return null for negative reset dates', () => { + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: -1, + }, + }) + const result = parseCurrentPlan(data) + expect(result.reset.apiRateLimit).toBeNull() + }) + + it('should return null when reset date is in the past', () => { + const pastDate = Date.now() - 86400000 // Yesterday + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: pastDate, + }, + }) + const result = parseCurrentPlan(data) + expect(result.reset.apiRateLimit).toBeNull() + }) + + it('should handle missing apps field', () => { + const data = createMockPlanData() + // @ts-expect-error - Testing edge case + delete data.apps + const result = parseCurrentPlan(data) + expect(result.usage.buildApps).toBe(0) + }) + + it('should return null for unrecognized date format', () => { + const data = createMockPlanData({ + api_rate_limit: { + usage: 100, + limit: 5000, + reset_date: 12345, // Unrecognized format + }, + }) + const result = parseCurrentPlan(data) + expect(result.reset.apiRateLimit).toBeNull() + }) + }) +}) diff --git a/web/app/components/datasets/api/index.spec.tsx b/web/app/components/datasets/api/index.spec.tsx new file mode 100644 index 0000000000..33ee656a23 --- /dev/null +++ b/web/app/components/datasets/api/index.spec.tsx @@ -0,0 +1,24 @@ +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it } from 'vitest' +import ApiIndex from './index' + +afterEach(() => { + cleanup() +}) + +describe('ApiIndex', () => { + it('should render without crashing', () => { + render() + expect(screen.getByText('index')).toBeInTheDocument() + }) + + it('should render a div with text "index"', () => { + const { container } = render() + expect(container.firstChild).toBeInstanceOf(HTMLDivElement) + expect(container.textContent).toBe('index') + }) + + it('should be a valid function component', () => { + expect(typeof ApiIndex).toBe('function') + }) +}) diff --git a/web/app/components/datasets/chunk.spec.tsx b/web/app/components/datasets/chunk.spec.tsx new file mode 100644 index 0000000000..d3dc011aef --- /dev/null +++ b/web/app/components/datasets/chunk.spec.tsx @@ -0,0 +1,111 @@ +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it } from 'vitest' +import { ChunkContainer, ChunkLabel, QAPreview } from './chunk' + +afterEach(() => { + cleanup() +}) + +describe('ChunkLabel', () => { + it('should render label text', () => { + render() + expect(screen.getByText('Chunk 1')).toBeInTheDocument() + }) + + it('should render character count', () => { + render() + expect(screen.getByText('150 characters')).toBeInTheDocument() + }) + + it('should render separator dot', () => { + render() + expect(screen.getByText('ยท')).toBeInTheDocument() + }) + + it('should render with zero character count', () => { + render() + expect(screen.getByText('0 characters')).toBeInTheDocument() + }) + + it('should render with large character count', () => { + render() + expect(screen.getByText('999999 characters')).toBeInTheDocument() + }) +}) + +describe('ChunkContainer', () => { + it('should render label and character count', () => { + render(Content) + expect(screen.getByText('Container 1')).toBeInTheDocument() + expect(screen.getByText('200 characters')).toBeInTheDocument() + }) + + it('should render children content', () => { + render(Test Content) + expect(screen.getByText('Test Content')).toBeInTheDocument() + }) + + it('should render with complex children', () => { + render( + +
+ Nested content +
+
, + ) + expect(screen.getByTestId('child-div')).toBeInTheDocument() + expect(screen.getByText('Nested content')).toBeInTheDocument() + }) + + it('should render empty children', () => { + render({null}) + expect(screen.getByText('Empty')).toBeInTheDocument() + }) +}) + +describe('QAPreview', () => { + const mockQA = { + question: 'What is the meaning of life?', + answer: 'The meaning of life is 42.', + } + + it('should render question text', () => { + render() + expect(screen.getByText('What is the meaning of life?')).toBeInTheDocument() + }) + + it('should render answer text', () => { + render() + expect(screen.getByText('The meaning of life is 42.')).toBeInTheDocument() + }) + + it('should render Q label', () => { + render() + expect(screen.getByText('Q')).toBeInTheDocument() + }) + + it('should render A label', () => { + render() + expect(screen.getByText('A')).toBeInTheDocument() + }) + + it('should render with empty strings', () => { + render() + expect(screen.getByText('Q')).toBeInTheDocument() + expect(screen.getByText('A')).toBeInTheDocument() + }) + + it('should render with long text', () => { + const longQuestion = 'Q'.repeat(500) + const longAnswer = 'A'.repeat(500) + render() + expect(screen.getByText(longQuestion)).toBeInTheDocument() + expect(screen.getByText(longAnswer)).toBeInTheDocument() + }) + + it('should render with special characters', () => { + render(?', answer: '& special chars!' }} />) + expect(screen.getByText('What about & < > " \'' - renderWithProviders( - , + it('should show error notification when operation fails', async () => { + vi.useFakeTimers() + mockEnable.mockRejectedValue(new Error('API Error')) + render( + , ) - - // Act - hover to show tooltip - const tooltipTrigger = screen.getByTestId('error-tooltip-trigger') - fireEvent.mouseEnter(tooltipTrigger) - - // Assert - await waitFor(() => { - expect(screen.getByText(specialChars)).toBeInTheDocument() + const switchElement = document.querySelector('[role="switch"]') + await act(async () => { + fireEvent.click(switchElement!) }) - }) - - it('should handle all status types in sequence', () => { - // Arrange - const statuses: DocumentDisplayStatus[] = [ - 'queuing', - 'indexing', - 'paused', - 'error', - 'available', - 'enabled', - 'disabled', - 'archived', - ] - - // Act & Assert - statuses.forEach((status) => { - const { unmount } = renderWithProviders() - const indicator = screen.getByTestId('status-indicator') - expect(indicator).toBeInTheDocument() - unmount() + await act(async () => { + vi.advanceTimersByTime(600) + // Flush promises + await Promise.resolve() }) + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'actionMsg.modifiedUnsuccessfully', + }) + vi.useRealTimers() }) }) - // ==================== Component Memoization ==================== - // Test React.memo behavior - describe('Component Memoization', () => { + describe('status color mapping', () => { + it('should have correct color class for green status', () => { + const { container } = render() + const text = container.querySelector('.text-util-colors-green-green-600') + expect(text).toBeInTheDocument() + }) + + it('should have correct color class for orange status', () => { + const { container } = render() + const text = container.querySelector('.text-util-colors-warning-warning-600') + expect(text).toBeInTheDocument() + }) + + it('should have correct color class for red status', () => { + const { container } = render() + const text = container.querySelector('.text-util-colors-red-red-600') + expect(text).toBeInTheDocument() + }) + + it('should have correct color class for blue status', () => { + const { container } = render() + const text = container.querySelector('.text-util-colors-blue-light-blue-light-600') + expect(text).toBeInTheDocument() + }) + + it('should have correct color class for gray status', () => { + const { container } = render() + const text = container.querySelector('.text-text-tertiary') + expect(text).toBeInTheDocument() + }) + }) + + describe('memoization', () => { it('should be wrapped with React.memo', () => { - // Assert - expect(StatusItem).toHaveProperty('$$typeof', Symbol.for('react.memo')) - }) - - it('should render correctly with same props', () => { - // Arrange - const props = { - status: 'available' as const, - scene: 'detail' as const, - detail: createDetailProps(), - } - - // Act - const { rerender } = renderWithProviders() - rerender( - - - , - ) - - // Assert - const indicator = screen.getByTestId('status-indicator') - expect(indicator).toBeInTheDocument() - }) - - it('should update when status prop changes', () => { - // Arrange - const { rerender } = renderWithProviders() - - // Assert initial - green/success background - let indicator = screen.getByTestId('status-indicator') - expect(indicator).toHaveClass('bg-components-badge-status-light-success-bg') - - // Act - rerender( - - - , - ) - - // Assert updated - red/error background - indicator = screen.getByTestId('status-indicator') - expect(indicator).toHaveClass('bg-components-badge-status-light-error-bg') + expect((StatusItem as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) }) }) - // ==================== Styling Tests ==================== - // Test CSS classes and styling - describe('Styling', () => { - it('should apply correct status text color for green status', () => { - // Arrange & Act - renderWithProviders() - - // Assert - const statusText = screen.getByText('datasetDocuments.list.status.available') - expect(statusText).toHaveClass('text-util-colors-green-green-600') - }) - - it('should apply correct status text color for red status', () => { - // Arrange & Act - renderWithProviders() - - // Assert - const statusText = screen.getByText('datasetDocuments.list.status.error') - expect(statusText).toHaveClass('text-util-colors-red-red-600') - }) - - it('should apply correct status text color for orange status', () => { - // Arrange & Act - renderWithProviders() - - // Assert - const statusText = screen.getByText('datasetDocuments.list.status.queuing') - expect(statusText).toHaveClass('text-util-colors-warning-warning-600') - }) - - it('should apply correct status text color for blue status', () => { - // Arrange & Act - renderWithProviders() - - // Assert - const statusText = screen.getByText('datasetDocuments.list.status.indexing') - expect(statusText).toHaveClass('text-util-colors-blue-light-blue-light-600') - }) - - it('should apply correct status text color for gray status', () => { - // Arrange & Act - renderWithProviders() - - // Assert - const statusText = screen.getByText('datasetDocuments.list.status.disabled') - expect(statusText).toHaveClass('text-text-tertiary') - }) - - it('should render switch with md size in detail scene', () => { - // Arrange & Act - renderWithProviders( + describe('default props', () => { + it('should work with default datasetId', () => { + render( , ) + const switchElement = document.querySelector('[role="switch"]') + expect(switchElement).toBeInTheDocument() + }) - // Assert - check switch has the md size class (h-4 w-7) - const switchEl = screen.getByRole('switch') - expect(switchEl).toHaveClass('h-4', 'w-7') + it('should work without detail prop', () => { + render() + expect(screen.getByText('Available')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/extra-info/api-access/index.spec.tsx b/web/app/components/datasets/extra-info/api-access/index.spec.tsx index fb4930cbdb..19e6b1ebca 100644 --- a/web/app/components/datasets/extra-info/api-access/index.spec.tsx +++ b/web/app/components/datasets/extra-info/api-access/index.spec.tsx @@ -1,792 +1,137 @@ -import type { DataSet } from '@/models/datasets' -import { render, screen, waitFor } from '@testing-library/react' -import userEvent from '@testing-library/user-event' -import { beforeEach, describe, expect, it, vi } from 'vitest' - -// ============================================================================ -// Component Imports (after mocks) -// ============================================================================ - -import Card from './card' +import { act, cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' import ApiAccess from './index' -// ============================================================================ -// Mock Setup -// ============================================================================ - -// Mock next/navigation -vi.mock('next/navigation', () => ({ - useRouter: () => ({ - push: vi.fn(), - replace: vi.fn(), +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, }), - usePathname: () => '/test', - useSearchParams: () => new URLSearchParams(), })) -// Mock next/link -vi.mock('next/link', () => ({ - default: ({ children, href, ...props }: { children: React.ReactNode, href: string, [key: string]: unknown }) => ( - {children} - ), -})) - -// Dataset context mock data -const mockDataset: Partial = { - id: 'dataset-123', - name: 'Test Dataset', - enable_api: true, -} - -// Mock use-context-selector -vi.mock('use-context-selector', () => ({ - useContext: vi.fn(() => ({ dataset: mockDataset })), - useContextSelector: vi.fn((_, selector) => selector({ dataset: mockDataset })), - createContext: vi.fn(() => ({})), -})) - -// Mock dataset detail context -const mockMutateDatasetRes = vi.fn() +// Mock context and hooks for Card component vi.mock('@/context/dataset-detail', () => ({ - default: {}, - useDatasetDetailContext: vi.fn(() => ({ - dataset: mockDataset, - mutateDatasetRes: mockMutateDatasetRes, - })), - useDatasetDetailContextWithSelector: vi.fn((selector: (v: { dataset?: typeof mockDataset, mutateDatasetRes?: () => void }) => unknown) => - selector({ dataset: mockDataset as DataSet, mutateDatasetRes: mockMutateDatasetRes }), - ), + useDatasetDetailContextWithSelector: vi.fn(() => 'test-dataset-id'), })) -// Mock app context for workspace permissions -let mockIsCurrentWorkspaceManager = true vi.mock('@/context/app-context', () => ({ - useSelector: vi.fn((selector: (state: { isCurrentWorkspaceManager: boolean }) => unknown) => - selector({ isCurrentWorkspaceManager: mockIsCurrentWorkspaceManager }), - ), + useSelector: vi.fn(() => true), })) -// Mock service hooks -const mockEnableDatasetServiceApi = vi.fn(() => Promise.resolve({ result: 'success' })) -const mockDisableDatasetServiceApi = vi.fn(() => Promise.resolve({ result: 'success' })) +vi.mock('@/hooks/use-api-access-url', () => ({ + useDatasetApiAccessUrl: vi.fn(() => 'https://api.example.com/docs'), +})) vi.mock('@/service/knowledge/use-dataset', () => ({ - useDatasetApiBaseUrl: vi.fn(() => ({ - data: { api_base_url: 'https://api.example.com' }, - isLoading: false, - })), - useEnableDatasetServiceApi: vi.fn(() => ({ - mutateAsync: mockEnableDatasetServiceApi, - isPending: false, - })), - useDisableDatasetServiceApi: vi.fn(() => ({ - mutateAsync: mockDisableDatasetServiceApi, - isPending: false, - })), + useEnableDatasetServiceApi: vi.fn(() => ({ mutateAsync: vi.fn() })), + useDisableDatasetServiceApi: vi.fn(() => ({ mutateAsync: vi.fn() })), })) -// Mock API access URL hook -vi.mock('@/hooks/use-api-access-url', () => ({ - useDatasetApiAccessUrl: vi.fn(() => 'https://docs.dify.ai/api-reference/datasets'), -})) - -// ============================================================================ -// ApiAccess Component Tests -// ============================================================================ +afterEach(() => { + cleanup() +}) describe('ApiAccess', () => { - beforeEach(() => { - vi.clearAllMocks() + it('should render without crashing', () => { + render() + expect(screen.getByText('appMenus.apiAccess')).toBeInTheDocument() }) - // -------------------------------------------------------------------------- - // Rendering Tests - // -------------------------------------------------------------------------- - describe('Rendering', () => { - it('should render without crashing', () => { - render() - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - - it('should render API title when expanded', () => { - render() - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - - it('should not render API title when collapsed', () => { - render() - expect(screen.queryByText(/appMenus\.apiAccess/i)).not.toBeInTheDocument() - }) - - it('should render ApiAggregate icon', () => { - const { container } = render() - const icon = container.querySelector('svg') - expect(icon).toBeInTheDocument() - }) - - it('should render Indicator component', () => { - const { container } = render() - const indicatorElement = container.querySelector('.relative.flex.h-8') - expect(indicatorElement).toBeInTheDocument() - }) - - it('should render with proper container padding', () => { - const { container } = render() - const wrapper = container.firstChild as HTMLElement - expect(wrapper).toHaveClass('p-3', 'pt-2') - }) + it('should render API access text when expanded', () => { + render() + expect(screen.getByText('appMenus.apiAccess')).toBeInTheDocument() }) - // -------------------------------------------------------------------------- - // Props Variations Tests - // -------------------------------------------------------------------------- - describe('Props Variations', () => { - it('should apply compressed layout when expand is false', () => { - const { container } = render() - const triggerContainer = container.querySelector('[class*="w-8"]') - expect(triggerContainer).toBeInTheDocument() - }) - - it('should apply full width when expand is true', () => { - const { container } = render() - const trigger = container.querySelector('.w-full') - expect(trigger).toBeInTheDocument() - }) - - it('should pass apiEnabled=true to Indicator with green color', () => { - const { container } = render() - // Indicator uses color prop - test the visual presence - const indicatorContainer = container.querySelector('.relative.flex.h-8') - expect(indicatorContainer).toBeInTheDocument() - }) - - it('should pass apiEnabled=false to Indicator with yellow color', () => { - const { container } = render() - const indicatorContainer = container.querySelector('.relative.flex.h-8') - expect(indicatorContainer).toBeInTheDocument() - }) - - it('should position Indicator absolutely when collapsed', () => { - const { container } = render() - // When collapsed, Indicator has 'absolute -right-px -top-px' classes - const triggerDiv = container.querySelector('[class*="w-8"][class*="justify-center"]') - expect(triggerDiv).toBeInTheDocument() - }) + it('should not render API access text when collapsed', () => { + render() + expect(screen.queryByText('appMenus.apiAccess')).not.toBeInTheDocument() }) - // -------------------------------------------------------------------------- - // User Interactions Tests - // -------------------------------------------------------------------------- - describe('User Interactions', () => { - it('should toggle popup open state on click', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - expect(trigger).toBeInTheDocument() - - if (trigger) - await user.click(trigger) - - // After click, the popup should toggle (Card should be rendered via portal) - }) - - it('should apply hover styles on trigger', () => { - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('div[class*="cursor-pointer"]') - expect(trigger).toHaveClass('cursor-pointer') - }) - - it('should toggle open state from false to true on first click', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - // The handleToggle function should flip open from false to true - }) - - it('should toggle open state back to false on second click', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) { - await user.click(trigger) // open - await user.click(trigger) // close - } - - // The handleToggle function should flip open from true to false - }) - - it('should apply open state styling when popup is open', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - // When open, the trigger should have bg-state-base-hover class - }) + it('should render with apiEnabled=true', () => { + render() + expect(screen.getByText('appMenus.apiAccess')).toBeInTheDocument() }) - // -------------------------------------------------------------------------- - // Portal and Card Integration Tests - // -------------------------------------------------------------------------- - describe('Portal and Card Integration', () => { - it('should render Card component inside portal when open', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - // Wait for portal content to appear - await waitFor(() => { - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() - }) - }) - - it('should pass apiEnabled prop to Card component', async () => { - const user = userEvent.setup() - - render() - - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - await waitFor(() => { - expect(screen.getByText(/serviceApi\.disabled/i)).toBeInTheDocument() - }) - }) - - it('should use correct portal placement configuration', () => { - render() - // PortalToFollowElem is configured with placement="top-start" - // The component should render without errors - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - - it('should use correct portal offset configuration', () => { - render() - // PortalToFollowElem is configured with offset={{ mainAxis: 4, crossAxis: -4 }} - // The component should render without errors - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - }) - - // -------------------------------------------------------------------------- - // Edge Cases Tests - // -------------------------------------------------------------------------- - describe('Edge Cases', () => { - it('should handle rapid toggle clicks gracefully', async () => { - const user = userEvent.setup() - - const { container } = render() - - // Use a more specific selector to find the trigger in the main component - const trigger = container.querySelector('.p-3 [class*="cursor-pointer"]') - if (trigger) { - // Rapid clicks - await user.click(trigger) - await user.click(trigger) - await user.click(trigger) - } - - // Component should handle state changes without errors - use getAllByText since Card may be open - const elements = screen.getAllByText(/appMenus\.apiAccess/i) - expect(elements.length).toBeGreaterThan(0) - }) - - it('should render correctly when both expand and apiEnabled are false', () => { - render() - // Should render without title but with indicator - expect(screen.queryByText(/appMenus\.apiAccess/i)).not.toBeInTheDocument() - }) - - it('should maintain state across prop changes', () => { - const { rerender } = render() - - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - - rerender() - - // Component should still render after prop change - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - }) - - // -------------------------------------------------------------------------- - // Memoization Tests - // -------------------------------------------------------------------------- - describe('Memoization', () => { - it('should be memoized with React.memo', () => { - const { rerender } = render() - - rerender() - - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - - it('should not re-render unnecessarily with same props', () => { - const { rerender } = render() - - rerender() - rerender() - - expect(screen.getByText(/appMenus\.apiAccess/i)).toBeInTheDocument() - }) - }) -}) - -// ============================================================================ -// Card Component Tests -// ============================================================================ - -describe('Card (api-access)', () => { - beforeEach(() => { - vi.clearAllMocks() - mockIsCurrentWorkspaceManager = true - mockEnableDatasetServiceApi.mockResolvedValue({ result: 'success' }) - mockDisableDatasetServiceApi.mockResolvedValue({ result: 'success' }) - }) - - // -------------------------------------------------------------------------- - // Rendering Tests - // -------------------------------------------------------------------------- - describe('Rendering', () => { - it('should render without crashing', () => { - render() - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() - }) - - it('should display enabled status when API is enabled', () => { - render() - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() - }) - - it('should display disabled status when API is disabled', () => { - render() - expect(screen.getByText(/serviceApi\.disabled/i)).toBeInTheDocument() - }) - - it('should render switch component', () => { - render() - expect(screen.getByRole('switch')).toBeInTheDocument() - }) - - it('should render API Reference link', () => { - render() - expect(screen.getByText(/overview\.apiInfo\.doc/i)).toBeInTheDocument() - }) - - it('should render Indicator component', () => { - const { container } = render() - // Indicator is rendered - verify card structure - const cardContainer = container.querySelector('.w-\\[208px\\]') - expect(cardContainer).toBeInTheDocument() - }) - - it('should render description tip text', () => { - render() - expect(screen.getByText(/appMenus\.apiAccessTip/i)).toBeInTheDocument() - }) - - it('should apply success text color when enabled', () => { - render() - const statusText = screen.getByText(/serviceApi\.enabled/i) - expect(statusText).toHaveClass('text-text-success') - }) - - it('should apply warning text color when disabled', () => { - render() - const statusText = screen.getByText(/serviceApi\.disabled/i) - expect(statusText).toHaveClass('text-text-warning') - }) - }) - - // -------------------------------------------------------------------------- - // User Interactions Tests - // -------------------------------------------------------------------------- - describe('User Interactions', () => { - it('should call enableDatasetServiceApi when switch is toggled on', async () => { - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockEnableDatasetServiceApi).toHaveBeenCalledWith('dataset-123') - }) - }) - - it('should call disableDatasetServiceApi when switch is toggled off', async () => { - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockDisableDatasetServiceApi).toHaveBeenCalledWith('dataset-123') - }) - }) - - it('should call mutateDatasetRes after successful API enable', async () => { - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockMutateDatasetRes).toHaveBeenCalled() - }) - }) - - it('should call mutateDatasetRes after successful API disable', async () => { - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockMutateDatasetRes).toHaveBeenCalled() - }) - }) - - it('should not call mutateDatasetRes on API enable failure', async () => { - mockEnableDatasetServiceApi.mockResolvedValueOnce({ result: 'fail' }) - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockEnableDatasetServiceApi).toHaveBeenCalled() - }) - - expect(mockMutateDatasetRes).not.toHaveBeenCalled() - }) - - it('should not call mutateDatasetRes on API disable failure', async () => { - mockDisableDatasetServiceApi.mockResolvedValueOnce({ result: 'fail' }) - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockDisableDatasetServiceApi).toHaveBeenCalled() - }) - - expect(mockMutateDatasetRes).not.toHaveBeenCalled() - }) - - it('should have correct href for API Reference link', () => { - render() - - const apiRefLink = screen.getByText(/overview\.apiInfo\.doc/i).closest('a') - expect(apiRefLink).toHaveAttribute('href', 'https://docs.dify.ai/api-reference/datasets') - }) - - it('should open API Reference in new tab', () => { - render() - - const apiRefLink = screen.getByText(/overview\.apiInfo\.doc/i).closest('a') - expect(apiRefLink).toHaveAttribute('target', '_blank') - expect(apiRefLink).toHaveAttribute('rel', 'noopener noreferrer') - }) - }) - - // -------------------------------------------------------------------------- - // Permission Handling Tests - // -------------------------------------------------------------------------- - describe('Permission Handling', () => { - it('should disable switch when user is not workspace manager', () => { - mockIsCurrentWorkspaceManager = false - - render() - - const switchButton = screen.getByRole('switch') - expect(switchButton).toHaveClass('!cursor-not-allowed') - expect(switchButton).toHaveClass('!opacity-50') - }) - - it('should enable switch when user is workspace manager', () => { - mockIsCurrentWorkspaceManager = true - - render() - - const switchButton = screen.getByRole('switch') - expect(switchButton).not.toHaveClass('!cursor-not-allowed') - expect(switchButton).not.toHaveClass('!opacity-50') - }) - - it('should not trigger API call when switch is disabled and clicked', async () => { - mockIsCurrentWorkspaceManager = false - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - // API should not be called when disabled - expect(mockEnableDatasetServiceApi).not.toHaveBeenCalled() - }) - }) - - // -------------------------------------------------------------------------- - // Edge Cases Tests - // -------------------------------------------------------------------------- - describe('Edge Cases', () => { - it('should handle empty datasetId gracefully', async () => { - const { useDatasetDetailContextWithSelector } = await import('@/context/dataset-detail') - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation((selector) => { - return selector({ - dataset: { ...mockDataset, id: '' } as DataSet, - mutateDatasetRes: mockMutateDatasetRes, - }) - }) - - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockEnableDatasetServiceApi).toHaveBeenCalledWith('') - }) - - // Reset mock - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation(selector => - selector({ dataset: mockDataset as DataSet, mutateDatasetRes: mockMutateDatasetRes }), - ) - }) - - it('should handle undefined datasetId gracefully when enabling API', async () => { - const { useDatasetDetailContextWithSelector } = await import('@/context/dataset-detail') - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation((selector) => { - const partialDataset = { ...mockDataset } as Partial - delete partialDataset.id - return selector({ - dataset: partialDataset as DataSet, - mutateDatasetRes: mockMutateDatasetRes, - }) - }) - - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - // Should use fallback empty string - expect(mockEnableDatasetServiceApi).toHaveBeenCalledWith('') - }) - - // Reset mock - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation(selector => - selector({ dataset: mockDataset as DataSet, mutateDatasetRes: mockMutateDatasetRes }), - ) - }) - - it('should handle undefined datasetId gracefully when disabling API', async () => { - const { useDatasetDetailContextWithSelector } = await import('@/context/dataset-detail') - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation((selector) => { - const partialDataset = { ...mockDataset } as Partial - delete partialDataset.id - return selector({ - dataset: partialDataset as DataSet, - mutateDatasetRes: mockMutateDatasetRes, - }) - }) - - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - // Should use fallback empty string for disableDatasetServiceApi - expect(mockDisableDatasetServiceApi).toHaveBeenCalledWith('') - }) - - // Reset mock - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation(selector => - selector({ dataset: mockDataset as DataSet, mutateDatasetRes: mockMutateDatasetRes }), - ) - }) - - it('should handle undefined mutateDatasetRes gracefully', async () => { - const { useDatasetDetailContextWithSelector } = await import('@/context/dataset-detail') - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation((selector) => { - return selector({ - dataset: mockDataset as DataSet, - mutateDatasetRes: undefined, - }) - }) - - const user = userEvent.setup() - - render() - - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - await waitFor(() => { - expect(mockEnableDatasetServiceApi).toHaveBeenCalled() - }) - - // Should not throw error when mutateDatasetRes is undefined - - // Reset mock - vi.mocked(useDatasetDetailContextWithSelector).mockImplementation(selector => - selector({ dataset: mockDataset as DataSet, mutateDatasetRes: mockMutateDatasetRes }), - ) - }) - }) - - // -------------------------------------------------------------------------- - // Memoization Tests - // -------------------------------------------------------------------------- - describe('Memoization', () => { - it('should be memoized with React.memo', () => { - const { rerender } = render() - - rerender() - - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() - }) - - it('should use useCallback for onToggle handler', () => { - const { rerender } = render() - - rerender() - - // Component should render without issues with memoized callbacks - expect(screen.getByRole('switch')).toBeInTheDocument() - }) - - it('should update when apiEnabled prop changes', () => { - const { rerender } = render() - - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() - - rerender() - - expect(screen.getByText(/serviceApi\.disabled/i)).toBeInTheDocument() - }) - }) -}) - -// ============================================================================ -// Integration Tests -// ============================================================================ - -describe('ApiAccess Integration', () => { - beforeEach(() => { - vi.clearAllMocks() - mockIsCurrentWorkspaceManager = true - mockEnableDatasetServiceApi.mockResolvedValue({ result: 'success' }) - mockDisableDatasetServiceApi.mockResolvedValue({ result: 'success' }) - }) - - it('should open Card popup and toggle API status', async () => { - const user = userEvent.setup() - + it('should render with apiEnabled=false', () => { render() + expect(screen.getByText('appMenus.apiAccess')).toBeInTheDocument() + }) - // Open popup - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) + it('should be wrapped with React.memo', () => { + expect((ApiAccess as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) - // Wait for Card to appear - await waitFor(() => { - expect(screen.getByText(/serviceApi\.disabled/i)).toBeInTheDocument() + describe('toggle functionality', () => { + it('should toggle open state when trigger is clicked', async () => { + const { container } = render() + const trigger = container.querySelector('.cursor-pointer') + expect(trigger).toBeInTheDocument() + + // Click to open + await act(async () => { + fireEvent.click(trigger!) + }) + + // The component should update its state - check for state change via class + expect(trigger).toBeInTheDocument() }) - // Toggle API on - const switchButton = screen.getByRole('switch') - await user.click(switchButton) + it('should toggle open state multiple times', async () => { + const { container } = render() + const trigger = container.querySelector('.cursor-pointer') - await waitFor(() => { - expect(mockEnableDatasetServiceApi).toHaveBeenCalledWith('dataset-123') + // First click - open + await act(async () => { + fireEvent.click(trigger!) + }) + + // Second click - close + await act(async () => { + fireEvent.click(trigger!) + }) + + expect(trigger).toBeInTheDocument() + }) + + it('should work when collapsed', async () => { + const { container } = render() + const trigger = container.querySelector('.cursor-pointer') + + await act(async () => { + fireEvent.click(trigger!) + }) + + expect(trigger).toBeInTheDocument() }) }) - it('should complete full workflow: open -> view status -> toggle -> verify callback', async () => { - const user = userEvent.setup() - - render() - - // Open popup - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - // Verify enabled status is shown - await waitFor(() => { - expect(screen.getByText(/serviceApi\.enabled/i)).toBeInTheDocument() + describe('indicator color', () => { + it('should render with green indicator when apiEnabled is true', () => { + const { container } = render() + // Indicator component should be present + const indicator = container.querySelector('.shrink-0') + expect(indicator).toBeInTheDocument() }) - // Toggle API off - const switchButton = screen.getByRole('switch') - await user.click(switchButton) - - // Verify API call and callback - await waitFor(() => { - expect(mockDisableDatasetServiceApi).toHaveBeenCalledWith('dataset-123') - expect(mockMutateDatasetRes).toHaveBeenCalled() + it('should render with yellow indicator when apiEnabled is false', () => { + const { container } = render() + const indicator = container.querySelector('.shrink-0') + expect(indicator).toBeInTheDocument() }) }) - it('should navigate to API Reference from Card', async () => { - const user = userEvent.setup() - - render() - - // Open popup - const trigger = screen.getByText(/appMenus\.apiAccess/i).closest('[class*="cursor-pointer"]') - if (trigger) - await user.click(trigger) - - // Wait for Card to appear - await waitFor(() => { - expect(screen.getByText(/overview\.apiInfo\.doc/i)).toBeInTheDocument() + describe('layout', () => { + it('should have justify-center when collapsed', () => { + const { container } = render() + const trigger = container.querySelector('.justify-center') + expect(trigger).toBeInTheDocument() }) - // Verify link - const apiRefLink = screen.getByText(/overview\.apiInfo\.doc/i).closest('a') - expect(apiRefLink).toHaveAttribute('href', 'https://docs.dify.ai/api-reference/datasets') + it('should not have justify-center when expanded', () => { + const { container } = render() + const innerDiv = container.querySelector('.cursor-pointer') + // When expanded, should have gap-2 and text, not justify-center + expect(innerDiv).not.toHaveClass('justify-center') + }) }) }) diff --git a/web/app/components/datasets/extra-info/statistics.spec.tsx b/web/app/components/datasets/extra-info/statistics.spec.tsx new file mode 100644 index 0000000000..d7f79a1ab2 --- /dev/null +++ b/web/app/components/datasets/extra-info/statistics.spec.tsx @@ -0,0 +1,87 @@ +import type { RelatedApp, RelatedAppResponse } from '@/models/datasets' +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import { AppModeEnum } from '@/types/app' +import Statistics from './statistics' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock useDocLink +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.example.com${path}`, +})) + +afterEach(() => { + cleanup() +}) + +describe('Statistics', () => { + const mockRelatedApp: RelatedApp = { + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon_type: 'emoji', + icon: '๐Ÿค–', + icon_background: '#ffffff', + icon_url: '', + } + + const mockRelatedApps: RelatedAppResponse = { + data: [mockRelatedApp], + total: 1, + } + + it('should render document count', () => { + render() + expect(screen.getByText('5')).toBeInTheDocument() + }) + + it('should render document label', () => { + render() + expect(screen.getByText('datasetMenus.documents')).toBeInTheDocument() + }) + + it('should render related apps total', () => { + render() + expect(screen.getByText('1')).toBeInTheDocument() + }) + + it('should render related app label', () => { + render() + expect(screen.getByText('datasetMenus.relatedApp')).toBeInTheDocument() + }) + + it('should render -- for undefined document count', () => { + render() + expect(screen.getByText('--')).toBeInTheDocument() + }) + + it('should render -- for undefined related apps total', () => { + render() + const dashes = screen.getAllByText('--') + expect(dashes.length).toBeGreaterThan(0) + }) + + it('should render with zero document count', () => { + render() + expect(screen.getByText('0')).toBeInTheDocument() + }) + + it('should render with empty related apps', () => { + const emptyRelatedApps: RelatedAppResponse = { + data: [], + total: 0, + } + render() + expect(screen.getByText('0')).toBeInTheDocument() + }) + + it('should be wrapped with React.memo', () => { + expect((Statistics as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) +}) diff --git a/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts index ad68a1df1c..4bd8357f1c 100644 --- a/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts +++ b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts @@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next' import Toast from '@/app/components/base/toast' import { useCheckDatasetUsage, useDeleteDataset } from '@/service/use-dataset-card' import { useExportPipelineDSL } from '@/service/use-pipeline' +import { downloadBlob } from '@/utils/download' type ModalState = { showRenameModal: boolean @@ -65,13 +66,8 @@ export const useDatasetCardState = ({ dataset, onSuccess }: UseDatasetCardStateO pipelineId: pipeline_id, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${name}.pipeline` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${name}.pipeline` }) } catch { Toast.notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) diff --git a/web/app/components/datasets/loading.spec.tsx b/web/app/components/datasets/loading.spec.tsx new file mode 100644 index 0000000000..0b291d727f --- /dev/null +++ b/web/app/components/datasets/loading.spec.tsx @@ -0,0 +1,21 @@ +import { cleanup, render } from '@testing-library/react' +import { afterEach, describe, expect, it } from 'vitest' +import DatasetsLoading from './loading' + +afterEach(() => { + cleanup() +}) + +describe('DatasetsLoading', () => { + it('should render null', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('should not throw on multiple renders', () => { + expect(() => { + render() + render() + }).not.toThrow() + }) +}) diff --git a/web/app/components/datasets/no-linked-apps-panel.spec.tsx b/web/app/components/datasets/no-linked-apps-panel.spec.tsx new file mode 100644 index 0000000000..aa66e43fbd --- /dev/null +++ b/web/app/components/datasets/no-linked-apps-panel.spec.tsx @@ -0,0 +1,58 @@ +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import NoLinkedAppsPanel from './no-linked-apps-panel' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock useDocLink +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.example.com${path}`, +})) + +afterEach(() => { + cleanup() +}) + +describe('NoLinkedAppsPanel', () => { + it('should render without crashing', () => { + render() + expect(screen.getByText('datasetMenus.emptyTip')).toBeInTheDocument() + }) + + it('should render the empty tip text', () => { + render() + expect(screen.getByText('datasetMenus.emptyTip')).toBeInTheDocument() + }) + + it('should render the view doc link', () => { + render() + expect(screen.getByText('datasetMenus.viewDoc')).toBeInTheDocument() + }) + + it('should render link with correct href', () => { + render() + const link = screen.getByRole('link') + expect(link).toHaveAttribute('href', 'https://docs.example.com/use-dify/knowledge/integrate-knowledge-within-application') + }) + + it('should render link with target="_blank"', () => { + render() + const link = screen.getByRole('link') + expect(link).toHaveAttribute('target', '_blank') + }) + + it('should render link with rel="noopener noreferrer"', () => { + render() + const link = screen.getByRole('link') + expect(link).toHaveAttribute('rel', 'noopener noreferrer') + }) + + it('should be wrapped with React.memo', () => { + expect((NoLinkedAppsPanel as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) +}) diff --git a/web/app/components/datasets/preview/index.spec.tsx b/web/app/components/datasets/preview/index.spec.tsx new file mode 100644 index 0000000000..56638fb612 --- /dev/null +++ b/web/app/components/datasets/preview/index.spec.tsx @@ -0,0 +1,25 @@ +import { cleanup, render } from '@testing-library/react' +import { afterEach, describe, expect, it } from 'vitest' +import DatasetPreview from './index' + +afterEach(() => { + cleanup() +}) + +describe('DatasetPreview', () => { + it('should render null', () => { + const { container } = render() + expect(container.firstChild).toBeNull() + }) + + it('should be a valid function component', () => { + expect(typeof DatasetPreview).toBe('function') + }) + + it('should not throw on multiple renders', () => { + expect(() => { + render() + render() + }).not.toThrow() + }) +}) diff --git a/web/app/components/develop/ApiServer.spec.tsx b/web/app/components/develop/ApiServer.spec.tsx new file mode 100644 index 0000000000..097eac578a --- /dev/null +++ b/web/app/components/develop/ApiServer.spec.tsx @@ -0,0 +1,220 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { act } from 'react' +import ApiServer from './ApiServer' + +// Mock the secret-key-modal since it involves complex API interactions +vi.mock('@/app/components/develop/secret-key/secret-key-modal', () => ({ + default: ({ isShow, onClose }: { isShow: boolean, onClose: () => void }) => ( + isShow ?
: null + ), +})) + +describe('ApiServer', () => { + const defaultProps = { + apiBaseUrl: 'https://api.example.com', + } + + describe('rendering', () => { + it('should render the API server label', () => { + render() + expect(screen.getByText('appApi.apiServer')).toBeInTheDocument() + }) + + it('should render the API base URL', () => { + render() + expect(screen.getByText('https://api.example.com')).toBeInTheDocument() + }) + + it('should render the OK status badge', () => { + render() + expect(screen.getByText('appApi.ok')).toBeInTheDocument() + }) + + it('should render the API key button', () => { + render() + expect(screen.getByText('appApi.apiKey')).toBeInTheDocument() + }) + + it('should render CopyFeedback component', () => { + render() + // CopyFeedback renders a button for copying + const copyButtons = screen.getAllByRole('button') + expect(copyButtons.length).toBeGreaterThan(0) + }) + }) + + describe('with different API URLs', () => { + it('should render localhost URL', () => { + render() + expect(screen.getByText('http://localhost:3000/api')).toBeInTheDocument() + }) + + it('should render production URL', () => { + render() + expect(screen.getByText('https://api.dify.ai/v1')).toBeInTheDocument() + }) + + it('should render URL with path', () => { + render() + expect(screen.getByText('https://api.example.com/v1/chat')).toBeInTheDocument() + }) + }) + + describe('with appId prop', () => { + it('should render without appId', () => { + render() + expect(screen.getByText('https://api.example.com')).toBeInTheDocument() + }) + + it('should render with appId', () => { + render() + expect(screen.getByText('https://api.example.com')).toBeInTheDocument() + }) + }) + + describe('SecretKeyButton interaction', () => { + it('should open modal when API key button is clicked', async () => { + const user = userEvent.setup() + render() + + const apiKeyButton = screen.getByText('appApi.apiKey') + await act(async () => { + await user.click(apiKeyButton) + }) + + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + }) + + it('should close modal when close button is clicked', async () => { + const user = userEvent.setup() + render() + + // Open modal + const apiKeyButton = screen.getByText('appApi.apiKey') + await act(async () => { + await user.click(apiKeyButton) + }) + + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + + // Close modal + const closeButton = screen.getByText('Close Modal') + await act(async () => { + await user.click(closeButton) + }) + + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have flex layout with wrap', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('flex') + expect(wrapper.className).toContain('flex-wrap') + }) + + it('should have items-center alignment', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('items-center') + }) + + it('should have gap-y-2 for vertical spacing', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('gap-y-2') + }) + + it('should apply green styling to OK badge', () => { + render() + const okBadge = screen.getByText('appApi.ok') + expect(okBadge.className).toContain('bg-[#ECFDF3]') + expect(okBadge.className).toContain('text-[#039855]') + }) + + it('should have border styling on URL container', () => { + render() + const urlText = screen.getByText('https://api.example.com') + const urlContainer = urlText.closest('div[class*="rounded-lg"]') + expect(urlContainer).toBeInTheDocument() + }) + }) + + describe('API server label', () => { + it('should have correct styling for label', () => { + render() + const label = screen.getByText('appApi.apiServer') + expect(label.className).toContain('rounded-md') + expect(label.className).toContain('border') + }) + + it('should have tertiary text color on label', () => { + render() + const label = screen.getByText('appApi.apiServer') + expect(label.className).toContain('text-text-tertiary') + }) + }) + + describe('URL display', () => { + it('should have truncate class for long URLs', () => { + render() + const urlText = screen.getByText('https://api.example.com') + expect(urlText.className).toContain('truncate') + }) + + it('should have font-medium class on URL', () => { + render() + const urlText = screen.getByText('https://api.example.com') + expect(urlText.className).toContain('font-medium') + }) + + it('should have secondary text color on URL', () => { + render() + const urlText = screen.getByText('https://api.example.com') + expect(urlText.className).toContain('text-text-secondary') + }) + }) + + describe('divider', () => { + it('should render vertical divider between URL and copy button', () => { + const { container } = render() + const divider = container.querySelector('.bg-divider-regular') + expect(divider).toBeInTheDocument() + }) + + it('should have correct divider dimensions', () => { + const { container } = render() + const divider = container.querySelector('.bg-divider-regular') + expect(divider?.className).toContain('h-[14px]') + expect(divider?.className).toContain('w-[1px]') + }) + }) + + describe('SecretKeyButton styling', () => { + it('should have shrink-0 class to prevent shrinking', () => { + render() + // The SecretKeyButton wraps a Button component + const button = screen.getByRole('button', { name: /apiKey/i }) + // Check parent container has shrink-0 + const buttonContainer = button.closest('.shrink-0') + expect(buttonContainer).toBeInTheDocument() + }) + }) + + describe('accessibility', () => { + it('should have accessible button for API key', () => { + render() + const button = screen.getByRole('button', { name: /apiKey/i }) + expect(button).toBeInTheDocument() + }) + + it('should have multiple buttons (copy + API key)', () => { + render() + const buttons = screen.getAllByRole('button') + expect(buttons.length).toBeGreaterThanOrEqual(2) + }) + }) +}) diff --git a/web/app/components/develop/code.spec.tsx b/web/app/components/develop/code.spec.tsx new file mode 100644 index 0000000000..b279c41a66 --- /dev/null +++ b/web/app/components/develop/code.spec.tsx @@ -0,0 +1,590 @@ +import { act, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { Code, CodeGroup, Embed, Pre } from './code' + +// Mock the clipboard utility +vi.mock('@/utils/clipboard', () => ({ + writeTextToClipboard: vi.fn().mockResolvedValue(undefined), +})) + +describe('code.tsx components', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.useFakeTimers({ shouldAdvanceTime: true }) + }) + + afterEach(() => { + vi.runOnlyPendingTimers() + vi.useRealTimers() + }) + + describe('Code', () => { + it('should render children', () => { + render(const x = 1) + expect(screen.getByText('const x = 1')).toBeInTheDocument() + }) + + it('should render as code element', () => { + render(code snippet) + const codeElement = screen.getByText('code snippet') + expect(codeElement.tagName).toBe('CODE') + }) + + it('should pass through additional props', () => { + render(snippet) + const codeElement = screen.getByTestId('custom-code') + expect(codeElement).toHaveClass('custom-class') + }) + + it('should render with complex children', () => { + render( + + part1 + part2 + , + ) + expect(screen.getByText('part1')).toBeInTheDocument() + expect(screen.getByText('part2')).toBeInTheDocument() + }) + }) + + describe('Embed', () => { + it('should render value prop', () => { + render(ignored children) + expect(screen.getByText('embedded content')).toBeInTheDocument() + }) + + it('should render as span element', () => { + render(children) + const span = screen.getByText('test value') + expect(span.tagName).toBe('SPAN') + }) + + it('should pass through additional props', () => { + render(children) + const embed = screen.getByTestId('embed-test') + expect(embed).toHaveClass('embed-class') + }) + + it('should not render children, only value', () => { + render(hidden children) + expect(screen.getByText('shown')).toBeInTheDocument() + expect(screen.queryByText('hidden children')).not.toBeInTheDocument() + }) + }) + + describe('CodeGroup', () => { + describe('with string targetCode', () => { + it('should render code from targetCode string', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('const hello = \'world\'')).toBeInTheDocument() + }) + + it('should have shadow and rounded styles', () => { + const { container } = render( + +
fallback
+
, + ) + const codeGroup = container.querySelector('.shadow-md') + expect(codeGroup).toBeInTheDocument() + expect(codeGroup).toHaveClass('rounded-2xl') + }) + + it('should have bg-zinc-900 background', () => { + const { container } = render( + +
fallback
+
, + ) + const codeGroup = container.querySelector('.bg-zinc-900') + expect(codeGroup).toBeInTheDocument() + }) + }) + + describe('with array targetCode', () => { + it('should render single code example without tabs', () => { + const examples = [{ code: 'single example' }] + render( + +
fallback
+
, + ) + expect(screen.getByText('single example')).toBeInTheDocument() + }) + + it('should render multiple code examples with tabs', () => { + const examples = [ + { title: 'JavaScript', code: 'console.log("js")' }, + { title: 'Python', code: 'print("py")' }, + ] + render( + +
fallback
+
, + ) + expect(screen.getByRole('tab', { name: 'JavaScript' })).toBeInTheDocument() + expect(screen.getByRole('tab', { name: 'Python' })).toBeInTheDocument() + }) + + it('should show first tab content by default', () => { + const examples = [ + { title: 'Tab1', code: 'first content' }, + { title: 'Tab2', code: 'second content' }, + ] + render( + +
fallback
+
, + ) + expect(screen.getByText('first content')).toBeInTheDocument() + }) + + it('should switch tabs on click', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + const examples = [ + { title: 'Tab1', code: 'first content' }, + { title: 'Tab2', code: 'second content' }, + ] + render( + +
fallback
+
, + ) + + const tab2 = screen.getByRole('tab', { name: 'Tab2' }) + await act(async () => { + await user.click(tab2) + }) + + await waitFor(() => { + expect(screen.getByText('second content')).toBeInTheDocument() + }) + }) + + it('should use "Code" as default title when title not provided', () => { + const examples = [ + { code: 'example 1' }, + { code: 'example 2' }, + ] + render( + +
fallback
+
, + ) + const codeTabs = screen.getAllByRole('tab', { name: 'Code' }) + expect(codeTabs).toHaveLength(2) + }) + }) + + describe('with title prop', () => { + it('should render title in header', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('API Example')).toBeInTheDocument() + }) + + it('should render title in h3 element', () => { + render( + +
fallback
+
, + ) + const h3 = screen.getByRole('heading', { level: 3 }) + expect(h3).toHaveTextContent('Example Title') + }) + }) + + describe('with tag and label props', () => { + it('should render tag in code panel header', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('GET')).toBeInTheDocument() + }) + + it('should render label in code panel header', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('/api/users')).toBeInTheDocument() + }) + + it('should render both tag and label with separator', () => { + const { container } = render( + +
fallback
+
, + ) + expect(screen.getByText('POST')).toBeInTheDocument() + expect(screen.getByText('/api/create')).toBeInTheDocument() + // Separator should be present + const separator = container.querySelector('.rounded-full.bg-zinc-500') + expect(separator).toBeInTheDocument() + }) + }) + + describe('CopyButton functionality', () => { + it('should render copy button', () => { + render( + +
fallback
+
, + ) + const copyButton = screen.getByRole('button') + expect(copyButton).toBeInTheDocument() + }) + + it('should show "Copy" text initially', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('Copy')).toBeInTheDocument() + }) + + it('should show "Copied!" after clicking copy button', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + const { writeTextToClipboard } = await import('@/utils/clipboard') + + render( + +
fallback
+
, + ) + + const copyButton = screen.getByRole('button') + await act(async () => { + await user.click(copyButton) + }) + + await waitFor(() => { + expect(writeTextToClipboard).toHaveBeenCalledWith('code to copy') + }) + + expect(screen.getByText('Copied!')).toBeInTheDocument() + }) + + it('should reset copy state after timeout', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + + render( + +
fallback
+
, + ) + + const copyButton = screen.getByRole('button') + await act(async () => { + await user.click(copyButton) + }) + + await waitFor(() => { + expect(screen.getByText('Copied!')).toBeInTheDocument() + }) + + // Advance time past the timeout + await act(async () => { + vi.advanceTimersByTime(1500) + }) + + await waitFor(() => { + expect(screen.getByText('Copy')).toBeInTheDocument() + }) + }) + }) + + describe('without targetCode (using children)', () => { + it('should render children when no targetCode provided', () => { + render( + +
child code content
+
, + ) + expect(screen.getByText('child code content')).toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have not-prose class to prevent prose styling', () => { + const { container } = render( + +
fallback
+
, + ) + const codeGroup = container.querySelector('.not-prose') + expect(codeGroup).toBeInTheDocument() + }) + + it('should have my-6 margin', () => { + const { container } = render( + +
fallback
+
, + ) + const codeGroup = container.querySelector('.my-6') + expect(codeGroup).toBeInTheDocument() + }) + + it('should have overflow-hidden', () => { + const { container } = render( + +
fallback
+
, + ) + const codeGroup = container.querySelector('.overflow-hidden') + expect(codeGroup).toBeInTheDocument() + }) + }) + }) + + describe('Pre', () => { + describe('when outside CodeGroup context', () => { + it('should wrap children in CodeGroup', () => { + const { container } = render( +
+            
code content
+
, + ) + // Should render within a CodeGroup structure + const codeGroup = container.querySelector('.bg-zinc-900') + expect(codeGroup).toBeInTheDocument() + }) + + it('should pass props to CodeGroup', () => { + render( +
+            
code
+
, + ) + expect(screen.getByText('Pre Title')).toBeInTheDocument() + }) + }) + + describe('when inside CodeGroup context (isGrouped)', () => { + it('should return children directly without wrapping', () => { + render( + +
+              inner code
+            
+
, + ) + // The outer code should be rendered (from targetCode) + expect(screen.getByText('outer code')).toBeInTheDocument() + }) + }) + }) + + describe('CodePanelHeader (via CodeGroup)', () => { + it('should not render when neither tag nor label provided', () => { + const { container } = render( + +
fallback
+
, + ) + const headerDivider = container.querySelector('.border-b-white\\/7\\.5') + expect(headerDivider).not.toBeInTheDocument() + }) + + it('should render when only tag is provided', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('GET')).toBeInTheDocument() + }) + + it('should render when only label is provided', () => { + render( + +
fallback
+
, + ) + expect(screen.getByText('/api/endpoint')).toBeInTheDocument() + }) + + it('should render label with font-mono styling', () => { + render( + +
fallback
+
, + ) + const label = screen.getByText('/api/test') + expect(label.className).toContain('font-mono') + expect(label.className).toContain('text-xs') + }) + }) + + describe('CodeGroupHeader (via CodeGroup with multiple tabs)', () => { + it('should render tab list for multiple examples', () => { + const examples = [ + { title: 'cURL', code: 'curl example' }, + { title: 'Node.js', code: 'node example' }, + ] + render( + +
fallback
+
, + ) + expect(screen.getByRole('tablist')).toBeInTheDocument() + }) + + it('should style active tab differently', () => { + const examples = [ + { title: 'Active', code: 'active code' }, + { title: 'Inactive', code: 'inactive code' }, + ] + render( + +
fallback
+
, + ) + const activeTab = screen.getByRole('tab', { name: 'Active' }) + expect(activeTab.className).toContain('border-emerald-500') + expect(activeTab.className).toContain('text-emerald-400') + }) + + it('should have header background styling', () => { + const examples = [ + { title: 'Tab1', code: 'code1' }, + { title: 'Tab2', code: 'code2' }, + ] + const { container } = render( + +
fallback
+
, + ) + const header = container.querySelector('.bg-zinc-800') + expect(header).toBeInTheDocument() + }) + }) + + describe('CodePanel (via CodeGroup)', () => { + it('should render code in pre element', () => { + render( + +
fallback
+
, + ) + const preElement = screen.getByText('pre content').closest('pre') + expect(preElement).toBeInTheDocument() + }) + + it('should have text-white class on pre', () => { + render( + +
fallback
+
, + ) + const preElement = screen.getByText('white text').closest('pre') + expect(preElement?.className).toContain('text-white') + }) + + it('should have text-xs class on pre', () => { + render( + +
fallback
+
, + ) + const preElement = screen.getByText('small text').closest('pre') + expect(preElement?.className).toContain('text-xs') + }) + + it('should have overflow-x-auto on pre', () => { + render( + +
fallback
+
, + ) + const preElement = screen.getByText('scrollable').closest('pre') + expect(preElement?.className).toContain('overflow-x-auto') + }) + + it('should have p-4 padding on pre', () => { + render( + +
fallback
+
, + ) + const preElement = screen.getByText('padded').closest('pre') + expect(preElement?.className).toContain('p-4') + }) + }) + + describe('ClipboardIcon (via CopyButton in CodeGroup)', () => { + it('should render clipboard icon in copy button', () => { + render( + +
fallback
+
, + ) + const copyButton = screen.getByRole('button') + const svg = copyButton.querySelector('svg') + expect(svg).toBeInTheDocument() + expect(svg).toHaveAttribute('viewBox', '0 0 20 20') + }) + }) + + describe('edge cases', () => { + it('should handle empty string targetCode', () => { + render( + +
fallback
+
, + ) + // Should render copy button even with empty code + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should handle targetCode with special characters', () => { + const specialCode = '
&
' + render( + +
fallback
+
, + ) + expect(screen.getByText(specialCode)).toBeInTheDocument() + }) + + it('should handle multiline targetCode', () => { + const multilineCode = `line1 +line2 +line3` + render( + +
fallback
+
, + ) + // Multiline code should be rendered - use a partial match + expect(screen.getByText(/line1/)).toBeInTheDocument() + expect(screen.getByText(/line2/)).toBeInTheDocument() + expect(screen.getByText(/line3/)).toBeInTheDocument() + }) + + it('should handle examples with tag property', () => { + const examples = [ + { title: 'Example', tag: 'v1', code: 'versioned code' }, + ] + render( + +
fallback
+
, + ) + expect(screen.getByText('versioned code')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/develop/index.spec.tsx b/web/app/components/develop/index.spec.tsx new file mode 100644 index 0000000000..f90e33e691 --- /dev/null +++ b/web/app/components/develop/index.spec.tsx @@ -0,0 +1,339 @@ +import { render, screen } from '@testing-library/react' +import DevelopMain from './index' + +// Mock the app store with a factory function to control state +const mockAppDetailValue: { current: unknown } = { current: undefined } +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: unknown) => unknown) => { + const state = { appDetail: mockAppDetailValue.current } + return selector(state) + }, +})) + +// Mock the Doc component since it has complex dependencies +vi.mock('@/app/components/develop/doc', () => ({ + default: ({ appDetail }: { appDetail: { name?: string } | null }) => ( +
+ Doc Component - + {appDetail?.name} +
+ ), +})) + +// Mock the ApiServer component +vi.mock('@/app/components/develop/ApiServer', () => ({ + default: ({ apiBaseUrl, appId }: { apiBaseUrl: string, appId: string }) => ( +
+ API Server - + {apiBaseUrl} + {' '} + - + {appId} +
+ ), +})) + +describe('DevelopMain', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAppDetailValue.current = undefined + }) + + describe('loading state', () => { + it('should show loading when appDetail is undefined', () => { + mockAppDetailValue.current = undefined + render() + + // Loading component renders with role="status" + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('should show loading when appDetail is null', () => { + mockAppDetailValue.current = null + render() + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('should have centered loading container', () => { + mockAppDetailValue.current = undefined + const { container } = render() + + const loadingContainer = container.querySelector('.flex.h-full.items-center.justify-center') + expect(loadingContainer).toBeInTheDocument() + }) + + it('should have correct background on loading state', () => { + mockAppDetailValue.current = undefined + const { container } = render() + + const loadingContainer = container.querySelector('.bg-background-default') + expect(loadingContainer).toBeInTheDocument() + }) + }) + + describe('with appDetail loaded', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com/v1', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should render ApiServer component', () => { + render() + expect(screen.getByTestId('api-server')).toBeInTheDocument() + }) + + it('should pass api_base_url to ApiServer', () => { + render() + expect(screen.getByTestId('api-server')).toHaveTextContent('https://api.example.com/v1') + }) + + it('should pass appId to ApiServer', () => { + render() + expect(screen.getByTestId('api-server')).toHaveTextContent('app-123') + }) + + it('should render Doc component', () => { + render() + expect(screen.getByTestId('doc-component')).toBeInTheDocument() + }) + + it('should pass appDetail to Doc component', () => { + render() + expect(screen.getByTestId('doc-component')).toHaveTextContent('Test Application') + }) + + it('should not show loading when appDetail exists', () => { + render() + expect(screen.queryByRole('status')).not.toBeInTheDocument() + }) + }) + + describe('layout structure', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should have flex column layout', () => { + const { container } = render() + const mainContainer = container.firstChild as HTMLElement + expect(mainContainer.className).toContain('flex') + expect(mainContainer.className).toContain('flex-col') + }) + + it('should have relative positioning', () => { + const { container } = render() + const mainContainer = container.firstChild as HTMLElement + expect(mainContainer.className).toContain('relative') + }) + + it('should have full height', () => { + const { container } = render() + const mainContainer = container.firstChild as HTMLElement + expect(mainContainer.className).toContain('h-full') + }) + + it('should have overflow-hidden', () => { + const { container } = render() + const mainContainer = container.firstChild as HTMLElement + expect(mainContainer.className).toContain('overflow-hidden') + }) + }) + + describe('header section', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should have header with border', () => { + const { container } = render() + const header = container.querySelector('.border-b') + expect(header).toBeInTheDocument() + }) + + it('should have shrink-0 on header to prevent shrinking', () => { + const { container } = render() + const header = container.querySelector('.shrink-0') + expect(header).toBeInTheDocument() + }) + + it('should have horizontal padding on header', () => { + const { container } = render() + const header = container.querySelector('.px-6') + expect(header).toBeInTheDocument() + }) + + it('should have vertical padding on header', () => { + const { container } = render() + const header = container.querySelector('.py-2') + expect(header).toBeInTheDocument() + }) + + it('should have items centered in header', () => { + const { container } = render() + const header = container.querySelector('.items-center') + expect(header).toBeInTheDocument() + }) + + it('should have justify-between in header', () => { + const { container } = render() + const header = container.querySelector('.justify-between') + expect(header).toBeInTheDocument() + }) + }) + + describe('content section', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should have grow class for content area', () => { + const { container } = render() + const content = container.querySelector('.grow') + expect(content).toBeInTheDocument() + }) + + it('should have overflow-auto for content scrolling', () => { + const { container } = render() + const content = container.querySelector('.overflow-auto') + expect(content).toBeInTheDocument() + }) + + it('should have horizontal padding on content', () => { + const { container } = render() + const content = container.querySelector('.px-4') + expect(content).toBeInTheDocument() + }) + + it('should have vertical padding on content', () => { + const { container } = render() + const content = container.querySelector('.py-4') + expect(content).toBeInTheDocument() + }) + + it('should have responsive padding', () => { + const { container } = render() + const content = container.querySelector('[class*="sm:px-10"]') + expect(content).toBeInTheDocument() + }) + }) + + describe('with different appIds', () => { + const mockAppDetail = { + id: 'app-456', + name: 'Another App', + api_base_url: 'https://another-api.com', + mode: 'completion', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should pass different appId to ApiServer', () => { + render() + expect(screen.getByTestId('api-server')).toHaveTextContent('app-456') + }) + + it('should handle app with different api_base_url', () => { + render() + expect(screen.getByTestId('api-server')).toHaveTextContent('https://another-api.com') + }) + }) + + describe('empty state handling', () => { + it('should handle appDetail with minimal properties', () => { + mockAppDetailValue.current = { + api_base_url: 'https://api.test.com', + } + render() + expect(screen.getByTestId('api-server')).toBeInTheDocument() + }) + + it('should handle appDetail with empty api_base_url', () => { + mockAppDetailValue.current = { + api_base_url: '', + name: 'Empty URL App', + } + render() + expect(screen.getByTestId('api-server')).toBeInTheDocument() + }) + }) + + describe('title element', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should have title div with correct styling', () => { + const { container } = render() + const title = container.querySelector('.text-lg.font-medium.text-text-primary') + expect(title).toBeInTheDocument() + }) + + it('should render empty title div', () => { + const { container } = render() + const title = container.querySelector('.text-lg.font-medium.text-text-primary') + expect(title?.textContent).toBe('') + }) + }) + + describe('border styling', () => { + const mockAppDetail = { + id: 'app-123', + name: 'Test Application', + api_base_url: 'https://api.example.com', + mode: 'chat', + } + + beforeEach(() => { + mockAppDetailValue.current = mockAppDetail + }) + + it('should have solid border style', () => { + const { container } = render() + const header = container.querySelector('.border-solid') + expect(header).toBeInTheDocument() + }) + + it('should have divider regular color on border', () => { + const { container } = render() + const header = container.querySelector('.border-b-divider-regular') + expect(header).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/develop/md.spec.tsx b/web/app/components/develop/md.spec.tsx new file mode 100644 index 0000000000..8eab1c0ac8 --- /dev/null +++ b/web/app/components/develop/md.spec.tsx @@ -0,0 +1,655 @@ +import { render, screen } from '@testing-library/react' +import { Col, Heading, Properties, Property, PropertyInstruction, Row, SubProperty } from './md' + +describe('md.tsx components', () => { + describe('Heading', () => { + const defaultProps = { + url: '/api/messages', + method: 'GET' as const, + title: 'Get Messages', + name: '#get-messages', + } + + describe('rendering', () => { + it('should render the method badge', () => { + render() + expect(screen.getByText('GET')).toBeInTheDocument() + }) + + it('should render the url', () => { + render() + expect(screen.getByText('/api/messages')).toBeInTheDocument() + }) + + it('should render the title as a link', () => { + render() + const link = screen.getByRole('link', { name: 'Get Messages' }) + expect(link).toBeInTheDocument() + expect(link).toHaveAttribute('href', '#get-messages') + }) + + it('should render an anchor span with correct id', () => { + const { container } = render() + const anchor = container.querySelector('#get-messages') + expect(anchor).toBeInTheDocument() + }) + + it('should strip # prefix from name for id', () => { + const { container } = render() + const anchor = container.querySelector('#with-hash') + expect(anchor).toBeInTheDocument() + }) + }) + + describe('method styling', () => { + it('should apply emerald styles for GET method', () => { + render() + const badge = screen.getByText('GET') + expect(badge.className).toContain('text-emerald') + expect(badge.className).toContain('bg-emerald-400/10') + expect(badge.className).toContain('ring-emerald-300') + }) + + it('should apply sky styles for POST method', () => { + render() + const badge = screen.getByText('POST') + expect(badge.className).toContain('text-sky') + expect(badge.className).toContain('bg-sky-400/10') + expect(badge.className).toContain('ring-sky-300') + }) + + it('should apply amber styles for PUT method', () => { + render() + const badge = screen.getByText('PUT') + expect(badge.className).toContain('text-amber') + expect(badge.className).toContain('bg-amber-400/10') + expect(badge.className).toContain('ring-amber-300') + }) + + it('should apply rose styles for DELETE method', () => { + render() + const badge = screen.getByText('DELETE') + expect(badge.className).toContain('text-red') + expect(badge.className).toContain('bg-rose') + expect(badge.className).toContain('ring-rose') + }) + + it('should apply violet styles for PATCH method', () => { + render() + const badge = screen.getByText('PATCH') + expect(badge.className).toContain('text-violet') + expect(badge.className).toContain('bg-violet-400/10') + expect(badge.className).toContain('ring-violet-300') + }) + }) + + describe('badge base styles', () => { + it('should have rounded-lg class', () => { + render() + const badge = screen.getByText('GET') + expect(badge.className).toContain('rounded-lg') + }) + + it('should have font-mono class', () => { + render() + const badge = screen.getByText('GET') + expect(badge.className).toContain('font-mono') + }) + + it('should have font-semibold class', () => { + render() + const badge = screen.getByText('GET') + expect(badge.className).toContain('font-semibold') + }) + + it('should have ring-1 and ring-inset classes', () => { + render() + const badge = screen.getByText('GET') + expect(badge.className).toContain('ring-1') + expect(badge.className).toContain('ring-inset') + }) + }) + + describe('url styles', () => { + it('should have font-mono class on url', () => { + render() + const url = screen.getByText('/api/messages') + expect(url.className).toContain('font-mono') + }) + + it('should have text-xs class on url', () => { + render() + const url = screen.getByText('/api/messages') + expect(url.className).toContain('text-xs') + }) + + it('should have zinc text color on url', () => { + render() + const url = screen.getByText('/api/messages') + expect(url.className).toContain('text-zinc-400') + }) + }) + + describe('h2 element', () => { + it('should render title inside h2', () => { + render() + const h2 = screen.getByRole('heading', { level: 2 }) + expect(h2).toBeInTheDocument() + expect(h2).toHaveTextContent('Get Messages') + }) + + it('should have scroll-mt-32 class on h2', () => { + render() + const h2 = screen.getByRole('heading', { level: 2 }) + expect(h2.className).toContain('scroll-mt-32') + }) + }) + }) + + describe('Row', () => { + it('should render children', () => { + render( + +
Child 1
+
Child 2
+
, + ) + expect(screen.getByText('Child 1')).toBeInTheDocument() + expect(screen.getByText('Child 2')).toBeInTheDocument() + }) + + it('should have grid layout', () => { + const { container } = render( + +
Content
+
, + ) + const row = container.firstChild as HTMLElement + expect(row.className).toContain('grid') + expect(row.className).toContain('grid-cols-1') + }) + + it('should have gap classes', () => { + const { container } = render( + +
Content
+
, + ) + const row = container.firstChild as HTMLElement + expect(row.className).toContain('gap-x-16') + expect(row.className).toContain('gap-y-10') + }) + + it('should have xl responsive classes', () => { + const { container } = render( + +
Content
+
, + ) + const row = container.firstChild as HTMLElement + expect(row.className).toContain('xl:grid-cols-2') + expect(row.className).toContain('xl:!max-w-none') + }) + + it('should have items-start class', () => { + const { container } = render( + +
Content
+
, + ) + const row = container.firstChild as HTMLElement + expect(row.className).toContain('items-start') + }) + }) + + describe('Col', () => { + it('should render children', () => { + render( + +
Column Content
+ , + ) + expect(screen.getByText('Column Content')).toBeInTheDocument() + }) + + it('should have first/last child margin classes', () => { + const { container } = render( + +
Content
+ , + ) + const col = container.firstChild as HTMLElement + expect(col.className).toContain('[&>:first-child]:mt-0') + expect(col.className).toContain('[&>:last-child]:mb-0') + }) + + it('should apply sticky classes when sticky is true', () => { + const { container } = render( + +
Sticky Content
+ , + ) + const col = container.firstChild as HTMLElement + expect(col.className).toContain('xl:sticky') + expect(col.className).toContain('xl:top-24') + }) + + it('should not apply sticky classes when sticky is false', () => { + const { container } = render( + +
Non-sticky Content
+ , + ) + const col = container.firstChild as HTMLElement + expect(col.className).not.toContain('xl:sticky') + expect(col.className).not.toContain('xl:top-24') + }) + }) + + describe('Properties', () => { + it('should render children', () => { + render( + +
  • Property 1
  • +
  • Property 2
  • +
    , + ) + expect(screen.getByText('Property 1')).toBeInTheDocument() + expect(screen.getByText('Property 2')).toBeInTheDocument() + }) + + it('should render as ul with role list', () => { + render( + +
  • Property
  • +
    , + ) + const list = screen.getByRole('list') + expect(list).toBeInTheDocument() + expect(list.tagName).toBe('UL') + }) + + it('should have my-6 margin class', () => { + const { container } = render( + +
  • Property
  • +
    , + ) + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('my-6') + }) + + it('should have list-none class on ul', () => { + render( + +
  • Property
  • +
    , + ) + const list = screen.getByRole('list') + expect(list.className).toContain('list-none') + }) + + it('should have m-0 and p-0 classes on ul', () => { + render( + +
  • Property
  • +
    , + ) + const list = screen.getByRole('list') + expect(list.className).toContain('m-0') + expect(list.className).toContain('p-0') + }) + + it('should have divide-y class on ul', () => { + render( + +
  • Property
  • +
    , + ) + const list = screen.getByRole('list') + expect(list.className).toContain('divide-y') + }) + + it('should have max-w constraint class', () => { + render( + +
  • Property
  • +
    , + ) + const list = screen.getByRole('list') + expect(list.className).toContain('max-w-[calc(theme(maxWidth.lg)-theme(spacing.8))]') + }) + }) + + describe('Property', () => { + const defaultProps = { + name: 'user_id', + type: 'string', + anchor: false, + } + + it('should render name in code element', () => { + render( + + User identifier + , + ) + const code = screen.getByText('user_id') + expect(code.tagName).toBe('CODE') + }) + + it('should render type', () => { + render( + + User identifier + , + ) + expect(screen.getByText('string')).toBeInTheDocument() + }) + + it('should render children as description', () => { + render( + + User identifier + , + ) + expect(screen.getByText('User identifier')).toBeInTheDocument() + }) + + it('should render as li element', () => { + const { container } = render( + + Description + , + ) + expect(container.querySelector('li')).toBeInTheDocument() + }) + + it('should have m-0 class on li', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('m-0') + }) + + it('should have padding classes on li', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('px-0') + expect(li.className).toContain('py-4') + }) + + it('should have first:pt-0 and last:pb-0 classes', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('first:pt-0') + expect(li.className).toContain('last:pb-0') + }) + + it('should render dl element with proper structure', () => { + const { container } = render( + + Description + , + ) + expect(container.querySelector('dl')).toBeInTheDocument() + }) + + it('should have sr-only dt elements for accessibility', () => { + const { container } = render( + + User identifier + , + ) + const dtElements = container.querySelectorAll('dt') + expect(dtElements.length).toBe(3) + dtElements.forEach((dt) => { + expect(dt.className).toContain('sr-only') + }) + }) + + it('should have font-mono class on type', () => { + render( + + Description + , + ) + const typeElement = screen.getByText('string') + expect(typeElement.className).toContain('font-mono') + expect(typeElement.className).toContain('text-xs') + }) + }) + + describe('SubProperty', () => { + const defaultProps = { + name: 'sub_field', + type: 'number', + anchor: false, + } + + it('should render name in code element', () => { + render( + + Sub field description + , + ) + const code = screen.getByText('sub_field') + expect(code.tagName).toBe('CODE') + }) + + it('should render type', () => { + render( + + Sub field description + , + ) + expect(screen.getByText('number')).toBeInTheDocument() + }) + + it('should render children as description', () => { + render( + + Sub field description + , + ) + expect(screen.getByText('Sub field description')).toBeInTheDocument() + }) + + it('should render as li element', () => { + const { container } = render( + + Description + , + ) + expect(container.querySelector('li')).toBeInTheDocument() + }) + + it('should have m-0 class on li', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('m-0') + }) + + it('should have different padding than Property (py-1 vs py-4)', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('px-0') + expect(li.className).toContain('py-1') + }) + + it('should have last:pb-0 class', () => { + const { container } = render( + + Description + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('last:pb-0') + }) + + it('should render dl element with proper structure', () => { + const { container } = render( + + Description + , + ) + expect(container.querySelector('dl')).toBeInTheDocument() + }) + + it('should have sr-only dt elements for accessibility', () => { + const { container } = render( + + Sub field description + , + ) + const dtElements = container.querySelectorAll('dt') + expect(dtElements.length).toBe(3) + dtElements.forEach((dt) => { + expect(dt.className).toContain('sr-only') + }) + }) + + it('should have font-mono and text-xs on type', () => { + render( + + Description + , + ) + const typeElement = screen.getByText('number') + expect(typeElement.className).toContain('font-mono') + expect(typeElement.className).toContain('text-xs') + }) + }) + + describe('PropertyInstruction', () => { + it('should render children', () => { + render( + + This is an instruction + , + ) + expect(screen.getByText('This is an instruction')).toBeInTheDocument() + }) + + it('should render as li element', () => { + const { container } = render( + + Instruction text + , + ) + expect(container.querySelector('li')).toBeInTheDocument() + }) + + it('should have m-0 class', () => { + const { container } = render( + + Instruction + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('m-0') + }) + + it('should have padding classes', () => { + const { container } = render( + + Instruction + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('px-0') + expect(li.className).toContain('py-4') + }) + + it('should have italic class', () => { + const { container } = render( + + Instruction + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('italic') + }) + + it('should have first:pt-0 class', () => { + const { container } = render( + + Instruction + , + ) + const li = container.querySelector('li')! + expect(li.className).toContain('first:pt-0') + }) + }) + + describe('integration tests', () => { + it('should render Property inside Properties', () => { + render( + + + Unique identifier + + + Display name + + , + ) + + expect(screen.getByText('id')).toBeInTheDocument() + expect(screen.getByText('name')).toBeInTheDocument() + expect(screen.getByText('Unique identifier')).toBeInTheDocument() + expect(screen.getByText('Display name')).toBeInTheDocument() + }) + + it('should render Col inside Row', () => { + render( + + +
    Left column
    + + +
    Right column
    + +
    , + ) + + expect(screen.getByText('Left column')).toBeInTheDocument() + expect(screen.getByText('Right column')).toBeInTheDocument() + }) + + it('should render PropertyInstruction inside Properties', () => { + render( + + + Note: All fields are required + + + A required field + + , + ) + + expect(screen.getByText('Note: All fields are required')).toBeInTheDocument() + expect(screen.getByText('required_field')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/develop/secret-key/input-copy.spec.tsx b/web/app/components/develop/secret-key/input-copy.spec.tsx new file mode 100644 index 0000000000..0216f2bfad --- /dev/null +++ b/web/app/components/develop/secret-key/input-copy.spec.tsx @@ -0,0 +1,314 @@ +import { act, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import copy from 'copy-to-clipboard' +import InputCopy from './input-copy' + +// Mock copy-to-clipboard +vi.mock('copy-to-clipboard', () => ({ + default: vi.fn().mockReturnValue(true), +})) + +describe('InputCopy', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.useFakeTimers({ shouldAdvanceTime: true }) + }) + + afterEach(() => { + vi.runOnlyPendingTimers() + vi.useRealTimers() + }) + + describe('rendering', () => { + it('should render the value', () => { + render() + expect(screen.getByText('test-api-key-12345')).toBeInTheDocument() + }) + + it('should render with empty value by default', () => { + render() + // Empty string should be rendered + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should render children when provided', () => { + render( + + Custom Content + , + ) + expect(screen.getByTestId('custom-child')).toBeInTheDocument() + }) + + it('should render CopyFeedback component', () => { + render() + // CopyFeedback should render a button + const buttons = screen.getAllByRole('button') + expect(buttons.length).toBeGreaterThan(0) + }) + }) + + describe('styling', () => { + it('should apply custom className', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('custom-class') + }) + + it('should have flex layout', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('flex') + }) + + it('should have items-center alignment', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('items-center') + }) + + it('should have rounded-lg class', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('rounded-lg') + }) + + it('should have background class', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('bg-components-input-bg-normal') + }) + + it('should have hover state', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('hover:bg-state-base-hover') + }) + + it('should have py-2 padding', () => { + const { container } = render() + const wrapper = container.firstChild as HTMLElement + expect(wrapper.className).toContain('py-2') + }) + }) + + describe('copy functionality', () => { + it('should copy value when clicked', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + render() + + const copyableArea = screen.getByText('copy-this-value') + await act(async () => { + await user.click(copyableArea) + }) + + expect(copy).toHaveBeenCalledWith('copy-this-value') + }) + + it('should update copied state after clicking', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + render() + + const copyableArea = screen.getByText('test-value') + await act(async () => { + await user.click(copyableArea) + }) + + // Copy function should have been called + expect(copy).toHaveBeenCalledWith('test-value') + }) + + it('should reset copied state after timeout', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + render() + + const copyableArea = screen.getByText('test-value') + await act(async () => { + await user.click(copyableArea) + }) + + expect(copy).toHaveBeenCalledWith('test-value') + + // Advance time to reset the copied state + await act(async () => { + vi.advanceTimersByTime(1500) + }) + + // Component should still be functional + expect(screen.getByText('test-value')).toBeInTheDocument() + }) + + it('should render tooltip on value', () => { + render() + // Value should be wrapped in tooltip (tooltip shows on hover, not as visible text) + const valueText = screen.getByText('test-value') + expect(valueText).toBeInTheDocument() + }) + }) + + describe('tooltip', () => { + it('should render tooltip wrapper', () => { + render() + const valueText = screen.getByText('test') + expect(valueText).toBeInTheDocument() + }) + + it('should have cursor-pointer on clickable area', () => { + render() + const valueText = screen.getByText('test') + const clickableArea = valueText.closest('div[class*="cursor-pointer"]') + expect(clickableArea).toBeInTheDocument() + }) + }) + + describe('divider', () => { + it('should render vertical divider', () => { + const { container } = render() + const divider = container.querySelector('.bg-divider-regular') + expect(divider).toBeInTheDocument() + }) + + it('should have correct divider dimensions', () => { + const { container } = render() + const divider = container.querySelector('.bg-divider-regular') + expect(divider?.className).toContain('h-4') + expect(divider?.className).toContain('w-px') + }) + + it('should have shrink-0 on divider', () => { + const { container } = render() + const divider = container.querySelector('.bg-divider-regular') + expect(divider?.className).toContain('shrink-0') + }) + }) + + describe('value display', () => { + it('should have truncate class for long values', () => { + render() + const valueText = screen.getByText('very-long-api-key-value-that-might-overflow') + const container = valueText.closest('div[class*="truncate"]') + expect(container).toBeInTheDocument() + }) + + it('should have text-secondary color on value', () => { + render() + const valueText = screen.getByText('test-value') + expect(valueText.className).toContain('text-text-secondary') + }) + + it('should have absolute positioning for overlay', () => { + render() + const valueText = screen.getByText('test') + const container = valueText.closest('div[class*="absolute"]') + expect(container).toBeInTheDocument() + }) + }) + + describe('inner container', () => { + it('should have grow class on inner container', () => { + const { container } = render() + const innerContainer = container.querySelector('.grow') + expect(innerContainer).toBeInTheDocument() + }) + + it('should have h-5 height on inner container', () => { + const { container } = render() + const innerContainer = container.querySelector('.h-5') + expect(innerContainer).toBeInTheDocument() + }) + }) + + describe('with children', () => { + it('should render children before value', () => { + const { container } = render( + + Prefix: + , + ) + const children = container.querySelector('[data-testid="prefix"]') + expect(children).toBeInTheDocument() + }) + + it('should render both children and value', () => { + render( + + Label: + , + ) + expect(screen.getByText('Label:')).toBeInTheDocument() + expect(screen.getByText('api-key')).toBeInTheDocument() + }) + }) + + describe('CopyFeedback section', () => { + it('should have margin on CopyFeedback container', () => { + const { container } = render() + const copyFeedbackContainer = container.querySelector('.mx-1') + expect(copyFeedbackContainer).toBeInTheDocument() + }) + }) + + describe('relative container', () => { + it('should have relative positioning on value container', () => { + const { container } = render() + const relativeContainer = container.querySelector('.relative') + expect(relativeContainer).toBeInTheDocument() + }) + + it('should have grow on value container', () => { + const { container } = render() + // Find the relative container that also has grow + const valueContainer = container.querySelector('.relative.grow') + expect(valueContainer).toBeInTheDocument() + }) + + it('should have full height on value container', () => { + const { container } = render() + const valueContainer = container.querySelector('.relative.h-full') + expect(valueContainer).toBeInTheDocument() + }) + }) + + describe('edge cases', () => { + it('should handle undefined value', () => { + render() + // Should not crash + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should handle empty string value', () => { + render() + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should handle very long values', () => { + const longValue = 'a'.repeat(500) + render() + expect(screen.getByText(longValue)).toBeInTheDocument() + }) + + it('should handle special characters in value', () => { + const specialValue = 'key-with-special-chars!@#$%^&*()' + render() + expect(screen.getByText(specialValue)).toBeInTheDocument() + }) + }) + + describe('multiple clicks', () => { + it('should handle multiple rapid clicks', async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }) + render() + + const copyableArea = screen.getByText('test') + + // Click multiple times rapidly + await act(async () => { + await user.click(copyableArea) + await user.click(copyableArea) + await user.click(copyableArea) + }) + + expect(copy).toHaveBeenCalledTimes(3) + }) + }) +}) diff --git a/web/app/components/develop/secret-key/secret-key-button.spec.tsx b/web/app/components/develop/secret-key/secret-key-button.spec.tsx new file mode 100644 index 0000000000..4b4fbaab29 --- /dev/null +++ b/web/app/components/develop/secret-key/secret-key-button.spec.tsx @@ -0,0 +1,297 @@ +import { act, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import SecretKeyButton from './secret-key-button' + +// Mock the SecretKeyModal since it has complex dependencies +vi.mock('@/app/components/develop/secret-key/secret-key-modal', () => ({ + default: ({ isShow, onClose, appId }: { isShow: boolean, onClose: () => void, appId?: string }) => ( + isShow + ? ( +
    + {`Modal for ${appId || 'no-app'}`} + +
    + ) + : null + ), +})) + +describe('SecretKeyButton', () => { + describe('rendering', () => { + it('should render the button', () => { + render() + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('should render the API key text', () => { + render() + expect(screen.getByText('appApi.apiKey')).toBeInTheDocument() + }) + + it('should render the key icon', () => { + const { container } = render() + // RiKey2Line icon should be rendered as an svg + const svg = container.querySelector('svg') + expect(svg).toBeInTheDocument() + }) + + it('should not show modal initially', () => { + render() + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + }) + }) + + describe('button interaction', () => { + it('should open modal when button is clicked', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + }) + + it('should close modal when onClose is called', async () => { + const user = userEvent.setup() + render() + + // Open modal + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + + // Close modal + const closeButton = screen.getByTestId('close-modal') + await act(async () => { + await user.click(closeButton) + }) + + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + }) + + it('should toggle modal visibility', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + + // Open + await act(async () => { + await user.click(button) + }) + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + + // Close + const closeButton = screen.getByTestId('close-modal') + await act(async () => { + await user.click(closeButton) + }) + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + + // Open again + await act(async () => { + await user.click(button) + }) + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + }) + }) + + describe('props', () => { + it('should apply custom className', () => { + render() + const button = screen.getByRole('button') + expect(button.className).toContain('custom-class') + }) + + it('should pass appId to modal', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + expect(screen.getByText('Modal for app-123')).toBeInTheDocument() + }) + + it('should handle undefined appId', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + expect(screen.getByText('Modal for no-app')).toBeInTheDocument() + }) + + it('should apply custom textCls', () => { + render() + const text = screen.getByText('appApi.apiKey') + expect(text.className).toContain('custom-text-class') + }) + }) + + describe('button styling', () => { + it('should have px-3 padding', () => { + render() + const button = screen.getByRole('button') + expect(button.className).toContain('px-3') + }) + + it('should have small size', () => { + render() + const button = screen.getByRole('button') + expect(button.className).toContain('btn-small') + }) + + it('should have ghost variant', () => { + render() + const button = screen.getByRole('button') + expect(button.className).toContain('btn-ghost') + }) + }) + + describe('icon styling', () => { + it('should have icon container with flex layout', () => { + const { container } = render() + const iconContainer = container.querySelector('.flex.items-center.justify-center') + expect(iconContainer).toBeInTheDocument() + }) + + it('should have correct icon dimensions', () => { + const { container } = render() + const iconContainer = container.querySelector('.h-3\\.5.w-3\\.5') + expect(iconContainer).toBeInTheDocument() + }) + + it('should have tertiary text color on icon', () => { + const { container } = render() + const icon = container.querySelector('.text-text-tertiary') + expect(icon).toBeInTheDocument() + }) + }) + + describe('text styling', () => { + it('should have system-xs-medium class', () => { + render() + const text = screen.getByText('appApi.apiKey') + expect(text.className).toContain('system-xs-medium') + }) + + it('should have horizontal padding', () => { + render() + const text = screen.getByText('appApi.apiKey') + expect(text.className).toContain('px-[3px]') + }) + + it('should have tertiary text color', () => { + render() + const text = screen.getByText('appApi.apiKey') + expect(text.className).toContain('text-text-tertiary') + }) + }) + + describe('modal props', () => { + it('should pass isShow prop to modal', async () => { + const user = userEvent.setup() + render() + + // Initially modal should not be visible + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + // Now modal should be visible + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + }) + + it('should pass onClose callback to modal', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + await act(async () => { + await user.click(button) + }) + + const closeButton = screen.getByTestId('close-modal') + await act(async () => { + await user.click(closeButton) + }) + + // Modal should be closed after clicking close + expect(screen.queryByTestId('secret-key-modal')).not.toBeInTheDocument() + }) + }) + + describe('accessibility', () => { + it('should have accessible button', () => { + render() + const button = screen.getByRole('button') + expect(button).toBeInTheDocument() + }) + + it('should be keyboard accessible', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + button.focus() + expect(document.activeElement).toBe(button) + + // Press Enter to activate + await act(async () => { + await user.keyboard('{Enter}') + }) + + expect(screen.getByTestId('secret-key-modal')).toBeInTheDocument() + }) + }) + + describe('multiple instances', () => { + it('should work independently when multiple instances exist', async () => { + const user = userEvent.setup() + render( + <> + + + , + ) + + const buttons = screen.getAllByRole('button') + expect(buttons).toHaveLength(2) + + // Click first button + await act(async () => { + await user.click(buttons[0]) + }) + + expect(screen.getByText('Modal for app-1')).toBeInTheDocument() + + // Close first modal + const closeButton = screen.getByTestId('close-modal') + await act(async () => { + await user.click(closeButton) + }) + + // Click second button + await act(async () => { + await user.click(buttons[1]) + }) + + expect(screen.getByText('Modal for app-2')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/develop/secret-key/secret-key-generate.spec.tsx b/web/app/components/develop/secret-key/secret-key-generate.spec.tsx new file mode 100644 index 0000000000..5988d6b7f3 --- /dev/null +++ b/web/app/components/develop/secret-key/secret-key-generate.spec.tsx @@ -0,0 +1,302 @@ +import type { CreateApiKeyResponse } from '@/models/app' +import { act, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import SecretKeyGenerateModal from './secret-key-generate' + +// Helper to create a valid CreateApiKeyResponse +const createMockApiKey = (token: string): CreateApiKeyResponse => ({ + id: 'mock-id', + token, + created_at: '2024-01-01T00:00:00Z', +}) + +describe('SecretKeyGenerateModal', () => { + const defaultProps = { + isShow: true, + onClose: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering when shown', () => { + it('should render the modal when isShow is true', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + + it('should render the generate tips text', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.generateTips')).toBeInTheDocument() + }) + + it('should render the OK button', () => { + render() + expect(screen.getByText('appApi.actionMsg.ok')).toBeInTheDocument() + }) + + it('should render the close icon', () => { + render() + // Modal renders via portal, so query from document.body + const closeIcon = document.body.querySelector('svg.cursor-pointer') + expect(closeIcon).toBeInTheDocument() + }) + + it('should render InputCopy component', () => { + render() + expect(screen.getByText('test-token-123')).toBeInTheDocument() + }) + }) + + describe('rendering when hidden', () => { + it('should not render content when isShow is false', () => { + render() + expect(screen.queryByText('appApi.apiKeyModal.apiSecretKey')).not.toBeInTheDocument() + }) + }) + + describe('newKey prop', () => { + it('should display the token when newKey is provided', () => { + render() + expect(screen.getByText('sk-abc123xyz')).toBeInTheDocument() + }) + + it('should handle undefined newKey', () => { + render() + // Should not crash and modal should still render + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + + it('should handle newKey with empty token', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + + it('should display long tokens correctly', () => { + const longToken = `sk-${'a'.repeat(100)}` + render() + expect(screen.getByText(longToken)).toBeInTheDocument() + }) + }) + + describe('close functionality', () => { + it('should call onClose when X icon is clicked', async () => { + const user = userEvent.setup() + const onClose = vi.fn() + render() + + // Modal renders via portal + const closeIcon = document.body.querySelector('svg.cursor-pointer') + expect(closeIcon).toBeInTheDocument() + + await act(async () => { + await user.click(closeIcon!) + }) + + // HeadlessUI Dialog may trigger onClose multiple times (icon click handler + dialog close) + expect(onClose).toHaveBeenCalled() + }) + + it('should call onClose when OK button is clicked', async () => { + const user = userEvent.setup() + const onClose = vi.fn() + render() + + const okButton = screen.getByRole('button', { name: /ok/i }) + await act(async () => { + await user.click(okButton) + }) + + // HeadlessUI Dialog calls onClose both from button click and modal close + expect(onClose).toHaveBeenCalled() + }) + }) + + describe('className prop', () => { + it('should apply custom className', () => { + render( + , + ) + // Modal renders via portal + const modal = document.body.querySelector('.custom-modal-class') + expect(modal).toBeInTheDocument() + }) + + it('should apply shrink-0 class', () => { + render( + , + ) + // Modal renders via portal + const modal = document.body.querySelector('.shrink-0') + expect(modal).toBeInTheDocument() + }) + }) + + describe('modal styling', () => { + it('should have px-8 padding', () => { + render() + // Modal renders via portal + const modal = document.body.querySelector('.px-8') + expect(modal).toBeInTheDocument() + }) + }) + + describe('close icon styling', () => { + it('should have cursor-pointer class on close icon', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg.cursor-pointer') + expect(closeIcon).toBeInTheDocument() + }) + + it('should have correct dimensions on close icon', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg[class*="h-6"][class*="w-6"]') + expect(closeIcon).toBeInTheDocument() + }) + + it('should have tertiary text color on close icon', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg[class*="text-text-tertiary"]') + expect(closeIcon).toBeInTheDocument() + }) + }) + + describe('header section', () => { + it('should have flex justify-end on close container', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg.cursor-pointer') + const closeContainer = closeIcon?.parentElement + expect(closeContainer).toBeInTheDocument() + expect(closeContainer?.className).toContain('flex') + expect(closeContainer?.className).toContain('justify-end') + }) + + it('should have negative margin on close container', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg.cursor-pointer') + const closeContainer = closeIcon?.parentElement + expect(closeContainer).toBeInTheDocument() + expect(closeContainer?.className).toContain('-mr-2') + expect(closeContainer?.className).toContain('-mt-6') + }) + + it('should have bottom margin on close container', () => { + render() + // Modal renders via portal + const closeIcon = document.body.querySelector('svg.cursor-pointer') + const closeContainer = closeIcon?.parentElement + expect(closeContainer).toBeInTheDocument() + expect(closeContainer?.className).toContain('mb-4') + }) + }) + + describe('tips text styling', () => { + it('should have mt-1 margin on tips', () => { + render() + const tips = screen.getByText('appApi.apiKeyModal.generateTips') + expect(tips.className).toContain('mt-1') + }) + + it('should have correct font size', () => { + render() + const tips = screen.getByText('appApi.apiKeyModal.generateTips') + expect(tips.className).toContain('text-[13px]') + }) + + it('should have normal font weight', () => { + render() + const tips = screen.getByText('appApi.apiKeyModal.generateTips') + expect(tips.className).toContain('font-normal') + }) + + it('should have leading-5 line height', () => { + render() + const tips = screen.getByText('appApi.apiKeyModal.generateTips') + expect(tips.className).toContain('leading-5') + }) + + it('should have tertiary text color', () => { + render() + const tips = screen.getByText('appApi.apiKeyModal.generateTips') + expect(tips.className).toContain('text-text-tertiary') + }) + }) + + describe('InputCopy section', () => { + it('should render InputCopy with token value', () => { + render() + expect(screen.getByText('test-token')).toBeInTheDocument() + }) + + it('should have w-full class on InputCopy', () => { + render() + // The InputCopy component should have w-full + const inputText = screen.getByText('test') + const inputContainer = inputText.closest('.w-full') + expect(inputContainer).toBeInTheDocument() + }) + }) + + describe('OK button section', () => { + it('should render OK button', () => { + render() + const button = screen.getByRole('button', { name: /ok/i }) + expect(button).toBeInTheDocument() + }) + + it('should have button container with flex layout', () => { + render() + const button = screen.getByRole('button', { name: /ok/i }) + const container = button.parentElement + expect(container).toBeInTheDocument() + expect(container?.className).toContain('flex') + }) + + it('should have shrink-0 on button', () => { + render() + const button = screen.getByRole('button', { name: /ok/i }) + expect(button.className).toContain('shrink-0') + }) + }) + + describe('button text styling', () => { + it('should have text-xs font size on button text', () => { + render() + const buttonText = screen.getByText('appApi.actionMsg.ok') + expect(buttonText.className).toContain('text-xs') + }) + + it('should have font-medium on button text', () => { + render() + const buttonText = screen.getByText('appApi.actionMsg.ok') + expect(buttonText.className).toContain('font-medium') + }) + + it('should have secondary text color on button text', () => { + render() + const buttonText = screen.getByText('appApi.actionMsg.ok') + expect(buttonText.className).toContain('text-text-secondary') + }) + }) + + describe('default prop values', () => { + it('should default isShow to false', () => { + // When isShow is explicitly set to false + render() + expect(screen.queryByText('appApi.apiKeyModal.apiSecretKey')).not.toBeInTheDocument() + }) + }) + + describe('modal title', () => { + it('should display the correct title', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/develop/secret-key/secret-key-modal.spec.tsx b/web/app/components/develop/secret-key/secret-key-modal.spec.tsx new file mode 100644 index 0000000000..79c51759ea --- /dev/null +++ b/web/app/components/develop/secret-key/secret-key-modal.spec.tsx @@ -0,0 +1,614 @@ +import { act, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import SecretKeyModal from './secret-key-modal' + +// Mock the app context +const mockCurrentWorkspace = vi.fn().mockReturnValue({ + id: 'workspace-1', + name: 'Test Workspace', +}) +const mockIsCurrentWorkspaceManager = vi.fn().mockReturnValue(true) +const mockIsCurrentWorkspaceEditor = vi.fn().mockReturnValue(true) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + currentWorkspace: mockCurrentWorkspace(), + isCurrentWorkspaceManager: mockIsCurrentWorkspaceManager(), + isCurrentWorkspaceEditor: mockIsCurrentWorkspaceEditor(), + }), +})) + +// Mock the timestamp hook +vi.mock('@/hooks/use-timestamp', () => ({ + default: () => ({ + formatTime: vi.fn((value: number, _format: string) => `Formatted: ${value}`), + formatDate: vi.fn((value: string, _format: string) => `Formatted: ${value}`), + }), +})) + +// Mock API services +const mockCreateAppApikey = vi.fn().mockResolvedValue({ token: 'new-app-token-123' }) +const mockDelAppApikey = vi.fn().mockResolvedValue({}) +vi.mock('@/service/apps', () => ({ + createApikey: (...args: unknown[]) => mockCreateAppApikey(...args), + delApikey: (...args: unknown[]) => mockDelAppApikey(...args), +})) + +const mockCreateDatasetApikey = vi.fn().mockResolvedValue({ token: 'new-dataset-token-123' }) +const mockDelDatasetApikey = vi.fn().mockResolvedValue({}) +vi.mock('@/service/datasets', () => ({ + createApikey: (...args: unknown[]) => mockCreateDatasetApikey(...args), + delApikey: (...args: unknown[]) => mockDelDatasetApikey(...args), +})) + +// Mock React Query hooks for apps +const mockAppApiKeysData = vi.fn().mockReturnValue({ data: [] }) +const mockIsAppApiKeysLoading = vi.fn().mockReturnValue(false) +const mockInvalidateAppApiKeys = vi.fn() + +vi.mock('@/service/use-apps', () => ({ + useAppApiKeys: (_appId: string, _options: unknown) => ({ + data: mockAppApiKeysData(), + isLoading: mockIsAppApiKeysLoading(), + }), + useInvalidateAppApiKeys: () => mockInvalidateAppApiKeys, +})) + +// Mock React Query hooks for datasets +const mockDatasetApiKeysData = vi.fn().mockReturnValue({ data: [] }) +const mockIsDatasetApiKeysLoading = vi.fn().mockReturnValue(false) +const mockInvalidateDatasetApiKeys = vi.fn() + +vi.mock('@/service/knowledge/use-dataset', () => ({ + useDatasetApiKeys: (_options: unknown) => ({ + data: mockDatasetApiKeysData(), + isLoading: mockIsDatasetApiKeysLoading(), + }), + useInvalidateDatasetApiKeys: () => mockInvalidateDatasetApiKeys, +})) + +describe('SecretKeyModal', () => { + const defaultProps = { + isShow: true, + onClose: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + mockCurrentWorkspace.mockReturnValue({ id: 'workspace-1', name: 'Test Workspace' }) + mockIsCurrentWorkspaceManager.mockReturnValue(true) + mockIsCurrentWorkspaceEditor.mockReturnValue(true) + mockAppApiKeysData.mockReturnValue({ data: [] }) + mockIsAppApiKeysLoading.mockReturnValue(false) + mockDatasetApiKeysData.mockReturnValue({ data: [] }) + mockIsDatasetApiKeysLoading.mockReturnValue(false) + }) + + describe('rendering when shown', () => { + it('should render the modal when isShow is true', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + + it('should render the tips text', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.apiSecretKeyTips')).toBeInTheDocument() + }) + + it('should render the create new key button', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.createNewSecretKey')).toBeInTheDocument() + }) + + it('should render the close icon', () => { + render() + // Modal renders via portal, so we need to query from document.body + const closeIcon = document.body.querySelector('svg.cursor-pointer') + expect(closeIcon).toBeInTheDocument() + }) + }) + + describe('rendering when hidden', () => { + it('should not render content when isShow is false', () => { + render() + expect(screen.queryByText('appApi.apiKeyModal.apiSecretKey')).not.toBeInTheDocument() + }) + }) + + describe('loading state', () => { + it('should show loading when app API keys are loading', () => { + mockIsAppApiKeysLoading.mockReturnValue(true) + render() + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('should show loading when dataset API keys are loading', () => { + mockIsDatasetApiKeysLoading.mockReturnValue(true) + render() + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('should not show loading when data is loaded', () => { + mockIsAppApiKeysLoading.mockReturnValue(false) + render() + expect(screen.queryByRole('status')).not.toBeInTheDocument() + }) + }) + + describe('API keys list for app', () => { + const apiKeys = [ + { id: 'key-1', token: 'sk-abc123def456ghi789', created_at: 1700000000, last_used_at: 1700100000 }, + { id: 'key-2', token: 'sk-xyz987wvu654tsr321', created_at: 1700050000, last_used_at: null }, + ] + + beforeEach(() => { + mockAppApiKeysData.mockReturnValue({ data: apiKeys }) + }) + + it('should render API keys when available', () => { + render() + // Token 'sk-abc123def456ghi789' (21 chars) -> first 3 'sk-' + '...' + last 20 'k-abc123def456ghi789' + expect(screen.getByText('sk-...k-abc123def456ghi789')).toBeInTheDocument() + }) + + it('should render created time for keys', () => { + render() + expect(screen.getByText('Formatted: 1700000000')).toBeInTheDocument() + }) + + it('should render last used time for keys', () => { + render() + expect(screen.getByText('Formatted: 1700100000')).toBeInTheDocument() + }) + + it('should render "never" for keys without last_used_at', () => { + render() + expect(screen.getByText('appApi.never')).toBeInTheDocument() + }) + + it('should render delete button for managers', () => { + render() + // Delete button contains RiDeleteBinLine SVG - look for SVGs with h-4 w-4 class within buttons + const buttons = screen.getAllByRole('button') + // There should be at least 3 buttons: copy feedback, delete, and create + expect(buttons.length).toBeGreaterThanOrEqual(2) + // Check for delete icon SVG - Modal renders via portal + const deleteIcon = document.body.querySelector('svg[class*="h-4"][class*="w-4"]') + expect(deleteIcon).toBeInTheDocument() + }) + + it('should not render delete button for non-managers', () => { + mockIsCurrentWorkspaceManager.mockReturnValue(false) + render() + // The specific delete action button should not be present + const actionButtons = screen.getAllByRole('button') + // Should only have copy and create buttons, not delete + expect(actionButtons.length).toBeGreaterThan(0) + }) + + it('should render table headers', () => { + render() + expect(screen.getByText('appApi.apiKeyModal.secretKey')).toBeInTheDocument() + expect(screen.getByText('appApi.apiKeyModal.created')).toBeInTheDocument() + expect(screen.getByText('appApi.apiKeyModal.lastUsed')).toBeInTheDocument() + }) + }) + + describe('API keys list for dataset', () => { + const datasetKeys = [ + { id: 'dk-1', token: 'dk-abc123def456ghi789', created_at: 1700000000, last_used_at: 1700100000 }, + ] + + beforeEach(() => { + mockDatasetApiKeysData.mockReturnValue({ data: datasetKeys }) + }) + + it('should render dataset API keys when no appId', () => { + render() + // Token 'dk-abc123def456ghi789' (21 chars) -> first 3 'dk-' + '...' + last 20 'k-abc123def456ghi789' + expect(screen.getByText('dk-...k-abc123def456ghi789')).toBeInTheDocument() + }) + }) + + describe('close functionality', () => { + it('should call onClose when X icon is clicked', async () => { + const user = userEvent.setup() + const onClose = vi.fn() + render() + + // Modal renders via portal, so we need to query from document.body + const closeIcon = document.body.querySelector('svg.cursor-pointer') + expect(closeIcon).toBeInTheDocument() + + await act(async () => { + await user.click(closeIcon!) + }) + + expect(onClose).toHaveBeenCalledTimes(1) + }) + }) + + describe('create new key', () => { + it('should call create API for app when button is clicked', async () => { + const user = userEvent.setup() + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + expect(mockCreateAppApikey).toHaveBeenCalledWith({ + url: '/apps/app-123/api-keys', + body: {}, + }) + }) + }) + + it('should call create API for dataset when no appId', async () => { + const user = userEvent.setup() + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + expect(mockCreateDatasetApikey).toHaveBeenCalledWith({ + url: '/datasets/api-keys', + body: {}, + }) + }) + }) + + it('should show generate modal after creating key', async () => { + const user = userEvent.setup() + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + // The SecretKeyGenerateModal should be shown with the new token + expect(screen.getByText('appApi.apiKeyModal.generateTips')).toBeInTheDocument() + }) + }) + + it('should invalidate app API keys after creating', async () => { + const user = userEvent.setup() + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + expect(mockInvalidateAppApiKeys).toHaveBeenCalledWith('app-123') + }) + }) + + it('should invalidate dataset API keys after creating (no appId)', async () => { + const user = userEvent.setup() + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + expect(mockInvalidateDatasetApiKeys).toHaveBeenCalled() + }) + }) + + it('should disable create button when no workspace', () => { + mockCurrentWorkspace.mockReturnValue(null) + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey').closest('button') + expect(createButton).toBeDisabled() + }) + + it('should disable create button when not editor', () => { + mockIsCurrentWorkspaceEditor.mockReturnValue(false) + render() + + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey').closest('button') + expect(createButton).toBeDisabled() + }) + }) + + describe('delete key', () => { + const apiKeys = [ + { id: 'key-1', token: 'sk-abc123def456ghi789', created_at: 1700000000, last_used_at: 1700100000 }, + ] + + beforeEach(() => { + mockAppApiKeysData.mockReturnValue({ data: apiKeys }) + }) + + it('should render delete button for managers', () => { + render() + + // Find buttons that contain SVG (delete/copy buttons) + const actionButtons = screen.getAllByRole('button') + // There should be at least copy, delete, and create buttons + expect(actionButtons.length).toBeGreaterThanOrEqual(3) + }) + + it('should render API key row with actions', () => { + render() + + // Verify the truncated token is rendered + expect(screen.getByText('sk-...k-abc123def456ghi789')).toBeInTheDocument() + }) + + it('should have action buttons in the key row', () => { + render() + + // Check for action button containers - Modal renders via portal + const actionContainers = document.body.querySelectorAll('[class*="space-x-2"]') + expect(actionContainers.length).toBeGreaterThan(0) + }) + + it('should have delete button visible for managers', async () => { + render() + + // Find the delete button by looking for the button with the delete icon + const deleteIcon = document.body.querySelector('svg[class*="h-4"][class*="w-4"]') + const deleteButton = deleteIcon?.closest('button') + expect(deleteButton).toBeInTheDocument() + }) + + it('should show confirm dialog when delete button is clicked', async () => { + const user = userEvent.setup() + render() + + // Find delete button by action-btn class (second action button after copy) + const actionButtons = document.body.querySelectorAll('button.action-btn') + // The delete button is the second action button (first is copy) + const deleteButton = actionButtons[1] + expect(deleteButton).toBeInTheDocument() + + await act(async () => { + await user.click(deleteButton!) + }) + + // Confirm dialog should appear + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + expect(screen.getByText('appApi.actionMsg.deleteConfirmTips')).toBeInTheDocument() + }) + }) + + it('should call delete API for app when confirmed', async () => { + const user = userEvent.setup() + render() + + // Find and click delete button + const actionButtons = document.body.querySelectorAll('button.action-btn') + const deleteButton = actionButtons[1] + await act(async () => { + await user.click(deleteButton!) + }) + + // Wait for confirm dialog and click confirm + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + }) + + // Find and click the confirm button + const confirmButton = screen.getByText('common.operation.confirm') + await act(async () => { + await user.click(confirmButton) + }) + + await waitFor(() => { + expect(mockDelAppApikey).toHaveBeenCalledWith({ + url: '/apps/app-123/api-keys/key-1', + params: {}, + }) + }) + }) + + it('should invalidate app API keys after deleting', async () => { + const user = userEvent.setup() + render() + + // Find and click delete button + const actionButtons = document.body.querySelectorAll('button.action-btn') + const deleteButton = actionButtons[1] + await act(async () => { + await user.click(deleteButton!) + }) + + // Wait for confirm dialog and click confirm + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + }) + + const confirmButton = screen.getByText('common.operation.confirm') + await act(async () => { + await user.click(confirmButton) + }) + + await waitFor(() => { + expect(mockInvalidateAppApiKeys).toHaveBeenCalledWith('app-123') + }) + }) + + it('should close confirm dialog and clear delKeyId when cancel is clicked', async () => { + const user = userEvent.setup() + render() + + // Find and click delete button + const actionButtons = document.body.querySelectorAll('button.action-btn') + const deleteButton = actionButtons[1] + await act(async () => { + await user.click(deleteButton!) + }) + + // Wait for confirm dialog + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + }) + + // Click cancel button + const cancelButton = screen.getByText('common.operation.cancel') + await act(async () => { + await user.click(cancelButton) + }) + + // Confirm dialog should close + await waitFor(() => { + expect(screen.queryByText('appApi.actionMsg.deleteConfirmTitle')).not.toBeInTheDocument() + }) + + // Delete API should not be called + expect(mockDelAppApikey).not.toHaveBeenCalled() + }) + }) + + describe('delete key for dataset', () => { + const datasetKeys = [ + { id: 'dk-1', token: 'dk-abc123def456ghi789', created_at: 1700000000, last_used_at: 1700100000 }, + ] + + beforeEach(() => { + mockDatasetApiKeysData.mockReturnValue({ data: datasetKeys }) + }) + + it('should call delete API for dataset when no appId', async () => { + const user = userEvent.setup() + render() + + // Find and click delete button + const actionButtons = document.body.querySelectorAll('button.action-btn') + const deleteButton = actionButtons[1] + await act(async () => { + await user.click(deleteButton!) + }) + + // Wait for confirm dialog and click confirm + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + }) + + const confirmButton = screen.getByText('common.operation.confirm') + await act(async () => { + await user.click(confirmButton) + }) + + await waitFor(() => { + expect(mockDelDatasetApikey).toHaveBeenCalledWith({ + url: '/datasets/api-keys/dk-1', + params: {}, + }) + }) + }) + + it('should invalidate dataset API keys after deleting', async () => { + const user = userEvent.setup() + render() + + // Find and click delete button + const actionButtons = document.body.querySelectorAll('button.action-btn') + const deleteButton = actionButtons[1] + await act(async () => { + await user.click(deleteButton!) + }) + + // Wait for confirm dialog and click confirm + await waitFor(() => { + expect(screen.getByText('appApi.actionMsg.deleteConfirmTitle')).toBeInTheDocument() + }) + + const confirmButton = screen.getByText('common.operation.confirm') + await act(async () => { + await user.click(confirmButton) + }) + + await waitFor(() => { + expect(mockInvalidateDatasetApiKeys).toHaveBeenCalled() + }) + }) + }) + + describe('token truncation', () => { + it('should truncate token correctly', () => { + const apiKeys = [ + { id: 'key-1', token: 'sk-abcdefghijklmnopqrstuvwxyz1234567890', created_at: 1700000000, last_used_at: null }, + ] + mockAppApiKeysData.mockReturnValue({ data: apiKeys }) + + render() + + // Token format: first 3 chars + ... + last 20 chars + // 'sk-abcdefghijklmnopqrstuvwxyz1234567890' -> 'sk-...qrstuvwxyz1234567890' + expect(screen.getByText('sk-...qrstuvwxyz1234567890')).toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should render modal with expected structure', () => { + render() + // Modal should render and contain the title + expect(screen.getByText('appApi.apiKeyModal.apiSecretKey')).toBeInTheDocument() + }) + + it('should render create button with flex styling', () => { + render() + // Modal renders via portal, so query from document.body + const flexContainers = document.body.querySelectorAll('[class*="flex"]') + expect(flexContainers.length).toBeGreaterThan(0) + }) + }) + + describe('empty state', () => { + it('should not render table when no keys', () => { + mockAppApiKeysData.mockReturnValue({ data: [] }) + render() + + expect(screen.queryByText('appApi.apiKeyModal.secretKey')).not.toBeInTheDocument() + }) + + it('should not render table when data is null', () => { + mockAppApiKeysData.mockReturnValue(null) + render() + + expect(screen.queryByText('appApi.apiKeyModal.secretKey')).not.toBeInTheDocument() + }) + }) + + describe('SecretKeyGenerateModal', () => { + it('should close generate modal on close', async () => { + const user = userEvent.setup() + render() + + // Create a new key to open generate modal + const createButton = screen.getByText('appApi.apiKeyModal.createNewSecretKey') + await act(async () => { + await user.click(createButton) + }) + + await waitFor(() => { + expect(screen.getByText('appApi.apiKeyModal.generateTips')).toBeInTheDocument() + }) + + // Find and click the close/OK button in generate modal + const okButton = screen.getByText('appApi.actionMsg.ok') + await act(async () => { + await user.click(okButton) + }) + + await waitFor(() => { + expect(screen.queryByText('appApi.apiKeyModal.generateTips')).not.toBeInTheDocument() + }) + }) + }) +}) diff --git a/web/app/components/develop/tag.spec.tsx b/web/app/components/develop/tag.spec.tsx new file mode 100644 index 0000000000..60a12040fa --- /dev/null +++ b/web/app/components/develop/tag.spec.tsx @@ -0,0 +1,242 @@ +import { render, screen } from '@testing-library/react' +import { Tag } from './tag' + +describe('Tag', () => { + describe('rendering', () => { + it('should render children text', () => { + render(GET) + expect(screen.getByText('GET')).toBeInTheDocument() + }) + + it('should render as a span element', () => { + render(POST) + const tag = screen.getByText('POST') + expect(tag.tagName).toBe('SPAN') + }) + }) + + describe('default color mapping based on HTTP methods', () => { + it('should apply emerald color for GET method', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('text-emerald') + }) + + it('should apply sky color for POST method', () => { + render(POST) + const tag = screen.getByText('POST') + expect(tag.className).toContain('text-sky') + }) + + it('should apply amber color for PUT method', () => { + render(PUT) + const tag = screen.getByText('PUT') + expect(tag.className).toContain('text-amber') + }) + + it('should apply rose color for DELETE method', () => { + render(DELETE) + const tag = screen.getByText('DELETE') + expect(tag.className).toContain('text-red') + }) + + it('should apply emerald color for unknown methods', () => { + render(UNKNOWN) + const tag = screen.getByText('UNKNOWN') + expect(tag.className).toContain('text-emerald') + }) + + it('should handle lowercase method names', () => { + render(get) + const tag = screen.getByText('get') + expect(tag.className).toContain('text-emerald') + }) + + it('should handle mixed case method names', () => { + render(Post) + const tag = screen.getByText('Post') + expect(tag.className).toContain('text-sky') + }) + }) + + describe('custom color prop', () => { + it('should override default color with custom emerald color', () => { + render(CUSTOM) + const tag = screen.getByText('CUSTOM') + expect(tag.className).toContain('text-emerald') + }) + + it('should override default color with custom sky color', () => { + render(CUSTOM) + const tag = screen.getByText('CUSTOM') + expect(tag.className).toContain('text-sky') + }) + + it('should override default color with custom amber color', () => { + render(CUSTOM) + const tag = screen.getByText('CUSTOM') + expect(tag.className).toContain('text-amber') + }) + + it('should override default color with custom rose color', () => { + render(CUSTOM) + const tag = screen.getByText('CUSTOM') + expect(tag.className).toContain('text-red') + }) + + it('should override default color with custom zinc color', () => { + render(CUSTOM) + const tag = screen.getByText('CUSTOM') + expect(tag.className).toContain('text-zinc') + }) + + it('should override automatic color mapping with explicit color', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('text-sky') + }) + }) + + describe('variant styles', () => { + it('should apply medium variant styles by default', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('rounded-lg') + expect(tag.className).toContain('px-1.5') + expect(tag.className).toContain('ring-1') + expect(tag.className).toContain('ring-inset') + }) + + it('should apply small variant styles', () => { + render(GET) + const tag = screen.getByText('GET') + // Small variant should not have ring styles + expect(tag.className).not.toContain('rounded-lg') + expect(tag.className).not.toContain('ring-1') + }) + }) + + describe('base styles', () => { + it('should always have font-mono class', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('font-mono') + }) + + it('should always have correct font-size class', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('text-[0.625rem]') + }) + + it('should always have font-semibold class', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('font-semibold') + }) + + it('should always have leading-6 class', () => { + render(GET) + const tag = screen.getByText('GET') + expect(tag.className).toContain('leading-6') + }) + }) + + describe('color styles for medium variant', () => { + it('should apply full emerald medium styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('ring-emerald-300') + expect(tag.className).toContain('bg-emerald-400/10') + expect(tag.className).toContain('text-emerald-500') + }) + + it('should apply full sky medium styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('ring-sky-300') + expect(tag.className).toContain('bg-sky-400/10') + expect(tag.className).toContain('text-sky-500') + }) + + it('should apply full amber medium styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('ring-amber-300') + expect(tag.className).toContain('bg-amber-400/10') + expect(tag.className).toContain('text-amber-500') + }) + + it('should apply full rose medium styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('ring-rose-200') + expect(tag.className).toContain('bg-rose-50') + expect(tag.className).toContain('text-red-500') + }) + + it('should apply full zinc medium styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('ring-zinc-200') + expect(tag.className).toContain('bg-zinc-50') + expect(tag.className).toContain('text-zinc-500') + }) + }) + + describe('color styles for small variant', () => { + it('should apply emerald small styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('text-emerald-500') + // Small variant should not have background/ring styles + expect(tag.className).not.toContain('bg-emerald-400/10') + expect(tag.className).not.toContain('ring-emerald-300') + }) + + it('should apply sky small styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('text-sky-500') + }) + + it('should apply amber small styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('text-amber-500') + }) + + it('should apply rose small styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('text-red-500') + }) + + it('should apply zinc small styles', () => { + render(TEST) + const tag = screen.getByText('TEST') + expect(tag.className).toContain('text-zinc-400') + }) + }) + + describe('HTTP method color combinations', () => { + it('should correctly map PATCH to emerald (default)', () => { + render(PATCH) + const tag = screen.getByText('PATCH') + // PATCH is not in the valueColorMap, so it defaults to emerald + expect(tag.className).toContain('text-emerald') + }) + + it('should correctly render all standard HTTP methods', () => { + const methods = ['GET', 'POST', 'PUT', 'DELETE'] + const expectedColors = ['emerald', 'sky', 'amber', 'red'] + + methods.forEach((method, index) => { + const { unmount } = render({method}) + const tag = screen.getByText(method) + expect(tag.className).toContain(`text-${expectedColors[index]}`) + unmount() + }) + }) + }) +}) diff --git a/web/app/components/explore/banner/banner-item.spec.tsx b/web/app/components/explore/banner/banner-item.spec.tsx new file mode 100644 index 0000000000..c890c08dc5 --- /dev/null +++ b/web/app/components/explore/banner/banner-item.spec.tsx @@ -0,0 +1,483 @@ +import type { Banner } from '@/models/app' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { BannerItem } from './banner-item' + +const mockScrollTo = vi.fn() +const mockSlideNodes = vi.fn() + +vi.mock('@/app/components/base/carousel', () => ({ + useCarousel: () => ({ + api: { + scrollTo: mockScrollTo, + slideNodes: mockSlideNodes, + }, + selectedIndex: 0, + }), +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'banner.viewMore': 'View More', + } + return translations[key] || key + }, + }), +})) + +const createMockBanner = (overrides: Partial = {}): Banner => ({ + id: 'banner-1', + status: 'enabled', + link: 'https://example.com', + content: { + 'category': 'Featured', + 'title': 'Test Banner Title', + 'description': 'Test banner description text', + 'img-src': 'https://example.com/image.png', + }, + ...overrides, +} as Banner) + +// Mock ResizeObserver methods declared at module level and initialized +const mockResizeObserverObserve = vi.fn() +const mockResizeObserverDisconnect = vi.fn() + +// Create mock class outside of describe block for proper hoisting +class MockResizeObserver { + constructor(_callback: ResizeObserverCallback) { + // Store callback if needed + } + + observe(...args: Parameters) { + mockResizeObserverObserve(...args) + } + + disconnect() { + mockResizeObserverDisconnect() + } + + unobserve() { + // No-op + } +} + +describe('BannerItem', () => { + let mockWindowOpen: ReturnType + + beforeEach(() => { + mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null) + mockSlideNodes.mockReturnValue([{}, {}, {}]) // 3 slides + + vi.stubGlobal('ResizeObserver', MockResizeObserver) + + // Mock window.innerWidth for responsive tests + Object.defineProperty(window, 'innerWidth', { + writable: true, + configurable: true, + value: 1400, // Above RESPONSIVE_BREAKPOINT (1200) + }) + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + vi.unstubAllGlobals() + mockWindowOpen.mockRestore() + }) + + describe('basic rendering', () => { + it('renders banner category', () => { + const banner = createMockBanner() + render( + , + ) + + expect(screen.getByText('Featured')).toBeInTheDocument() + }) + + it('renders banner title', () => { + const banner = createMockBanner() + render( + , + ) + + expect(screen.getByText('Test Banner Title')).toBeInTheDocument() + }) + + it('renders banner description', () => { + const banner = createMockBanner() + render( + , + ) + + expect(screen.getByText('Test banner description text')).toBeInTheDocument() + }) + + it('renders banner image with correct src and alt', () => { + const banner = createMockBanner() + render( + , + ) + + const image = screen.getByRole('img') + expect(image).toHaveAttribute('src', 'https://example.com/image.png') + expect(image).toHaveAttribute('alt', 'Test Banner Title') + }) + + it('renders view more text', () => { + const banner = createMockBanner() + render( + , + ) + + expect(screen.getByText('View More')).toBeInTheDocument() + }) + }) + + describe('click handling', () => { + it('opens banner link in new tab when clicked', () => { + const banner = createMockBanner({ link: 'https://test-link.com' }) + render( + , + ) + + const bannerElement = screen.getByText('Test Banner Title').closest('div[class*="cursor-pointer"]') + fireEvent.click(bannerElement!) + + expect(mockWindowOpen).toHaveBeenCalledWith( + 'https://test-link.com', + '_blank', + 'noopener,noreferrer', + ) + }) + + it('does not open window when banner has no link', () => { + const banner = createMockBanner({ link: '' }) + render( + , + ) + + const bannerElement = screen.getByText('Test Banner Title').closest('div[class*="cursor-pointer"]') + fireEvent.click(bannerElement!) + + expect(mockWindowOpen).not.toHaveBeenCalled() + }) + }) + + describe('slide indicators', () => { + it('renders correct number of indicator buttons', () => { + mockSlideNodes.mockReturnValue([{}, {}, {}]) + const banner = createMockBanner() + render( + , + ) + + const buttons = screen.getAllByRole('button') + expect(buttons).toHaveLength(3) + }) + + it('renders indicator buttons with correct numbers', () => { + mockSlideNodes.mockReturnValue([{}, {}, {}]) + const banner = createMockBanner() + render( + , + ) + + expect(screen.getByText('01')).toBeInTheDocument() + expect(screen.getByText('02')).toBeInTheDocument() + expect(screen.getByText('03')).toBeInTheDocument() + }) + + it('calls scrollTo when indicator is clicked', () => { + mockSlideNodes.mockReturnValue([{}, {}, {}]) + const banner = createMockBanner() + render( + , + ) + + const secondIndicator = screen.getByText('02').closest('button') + fireEvent.click(secondIndicator!) + + expect(mockScrollTo).toHaveBeenCalledWith(1) + }) + + it('renders no indicators when no slides', () => { + mockSlideNodes.mockReturnValue([]) + const banner = createMockBanner() + render( + , + ) + + expect(screen.queryByRole('button')).not.toBeInTheDocument() + }) + }) + + describe('isPaused prop', () => { + it('defaults isPaused to false', () => { + const banner = createMockBanner() + render( + , + ) + + // Component should render without issues + expect(screen.getByText('Test Banner Title')).toBeInTheDocument() + }) + + it('accepts isPaused prop', () => { + const banner = createMockBanner() + render( + , + ) + + // Component should render with isPaused + expect(screen.getByText('Test Banner Title')).toBeInTheDocument() + }) + }) + + describe('responsive behavior', () => { + it('sets up ResizeObserver on mount', () => { + const banner = createMockBanner() + render( + , + ) + + expect(mockResizeObserverObserve).toHaveBeenCalled() + }) + + it('adds resize event listener on mount', () => { + const addEventListenerSpy = vi.spyOn(window, 'addEventListener') + const banner = createMockBanner() + render( + , + ) + + expect(addEventListenerSpy).toHaveBeenCalledWith('resize', expect.any(Function)) + addEventListenerSpy.mockRestore() + }) + + it('removes resize event listener on unmount', () => { + const removeEventListenerSpy = vi.spyOn(window, 'removeEventListener') + const banner = createMockBanner() + const { unmount } = render( + , + ) + + unmount() + + expect(removeEventListenerSpy).toHaveBeenCalledWith('resize', expect.any(Function)) + removeEventListenerSpy.mockRestore() + }) + + it('sets maxWidth when window width is below breakpoint', () => { + // Set window width below RESPONSIVE_BREAKPOINT (1200) + Object.defineProperty(window, 'innerWidth', { + writable: true, + configurable: true, + value: 1000, + }) + + const banner = createMockBanner() + render( + , + ) + + // Component should render and apply responsive styles + expect(screen.getByText('Test Banner Title')).toBeInTheDocument() + }) + + it('applies responsive styles when below breakpoint', () => { + // Set window width below RESPONSIVE_BREAKPOINT (1200) + Object.defineProperty(window, 'innerWidth', { + writable: true, + configurable: true, + value: 800, + }) + + const banner = createMockBanner() + render( + , + ) + + // The component should render even with responsive mode + expect(screen.getByText('View More')).toBeInTheDocument() + }) + }) + + describe('content variations', () => { + it('renders long category text', () => { + const banner = createMockBanner({ + content: { + 'category': 'Very Long Category Name', + 'title': 'Title', + 'description': 'Description', + 'img-src': 'https://example.com/img.png', + }, + } as Partial) + render( + , + ) + + expect(screen.getByText('Very Long Category Name')).toBeInTheDocument() + }) + + it('renders long title with truncation class', () => { + const banner = createMockBanner({ + content: { + 'category': 'Category', + 'title': 'A Very Long Title That Should Be Truncated Eventually', + 'description': 'Description', + 'img-src': 'https://example.com/img.png', + }, + } as Partial) + render( + , + ) + + const titleElement = screen.getByText('A Very Long Title That Should Be Truncated Eventually') + expect(titleElement).toHaveClass('line-clamp-2') + }) + + it('renders long description with truncation class', () => { + const banner = createMockBanner({ + content: { + 'category': 'Category', + 'title': 'Title', + 'description': 'A very long description that should be limited to a certain number of lines for proper display in the banner component.', + 'img-src': 'https://example.com/img.png', + }, + } as Partial) + render( + , + ) + + const descriptionElement = screen.getByText(/A very long description/) + expect(descriptionElement).toHaveClass('line-clamp-4') + }) + }) + + describe('slide calculation', () => { + it('calculates next index correctly for first slide', () => { + mockSlideNodes.mockReturnValue([{}, {}, {}]) + const banner = createMockBanner() + render( + , + ) + + // With selectedIndex=0 and 3 slides, nextIndex should be 1 + // The second indicator button should show the "next slide" state + const buttons = screen.getAllByRole('button') + expect(buttons).toHaveLength(3) + }) + + it('handles single slide case', () => { + mockSlideNodes.mockReturnValue([{}]) + const banner = createMockBanner() + render( + , + ) + + const buttons = screen.getAllByRole('button') + expect(buttons).toHaveLength(1) + }) + }) + + describe('wrapper styling', () => { + it('has cursor-pointer class', () => { + const banner = createMockBanner() + const { container } = render( + , + ) + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('cursor-pointer') + }) + + it('has rounded-2xl class', () => { + const banner = createMockBanner() + const { container } = render( + , + ) + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('rounded-2xl') + }) + }) +}) diff --git a/web/app/components/explore/banner/banner.spec.tsx b/web/app/components/explore/banner/banner.spec.tsx new file mode 100644 index 0000000000..de719c3936 --- /dev/null +++ b/web/app/components/explore/banner/banner.spec.tsx @@ -0,0 +1,472 @@ +import type * as React from 'react' +import type { Banner as BannerType } from '@/models/app' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import Banner from './banner' + +const mockUseGetBanners = vi.fn() + +vi.mock('@/service/use-explore', () => ({ + useGetBanners: (...args: unknown[]) => mockUseGetBanners(...args), +})) + +vi.mock('@/context/i18n', () => ({ + useLocale: () => 'en-US', +})) + +vi.mock('@/app/components/base/carousel', () => ({ + Carousel: Object.assign( + ({ children, onMouseEnter, onMouseLeave, className }: { + children: React.ReactNode + onMouseEnter?: () => void + onMouseLeave?: () => void + className?: string + }) => ( +
    + {children} +
    + ), + { + Content: ({ children }: { children: React.ReactNode }) => ( +
    {children}
    + ), + Item: ({ children }: { children: React.ReactNode }) => ( +
    {children}
    + ), + Plugin: { + Autoplay: (config: Record) => ({ type: 'autoplay', ...config }), + }, + }, + ), + useCarousel: () => ({ + api: { + scrollTo: vi.fn(), + slideNodes: () => [], + }, + selectedIndex: 0, + }), +})) + +vi.mock('./banner-item', () => ({ + BannerItem: ({ banner, autoplayDelay, isPaused }: { + banner: BannerType + autoplayDelay: number + isPaused?: boolean + }) => ( +
    + BannerItem: + {' '} + {banner.content.title} +
    + ), +})) + +const createMockBanner = (id: string, status: string = 'enabled', title: string = 'Test Banner'): BannerType => ({ + id, + status, + link: 'https://example.com', + content: { + 'category': 'Featured', + title, + 'description': 'Test description', + 'img-src': 'https://example.com/image.png', + }, +} as BannerType) + +describe('Banner', () => { + beforeEach(() => { + vi.useFakeTimers() + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + vi.useRealTimers() + }) + + describe('loading state', () => { + it('renders loading state when isLoading is true', () => { + mockUseGetBanners.mockReturnValue({ + data: null, + isLoading: true, + isError: false, + }) + + render() + + // Loading component renders a spinner + const loadingWrapper = document.querySelector('[style*="min-height"]') + expect(loadingWrapper).toBeInTheDocument() + }) + + it('shows loading indicator with correct minimum height', () => { + mockUseGetBanners.mockReturnValue({ + data: null, + isLoading: true, + isError: false, + }) + + render() + + const loadingWrapper = document.querySelector('[style*="min-height: 168px"]') + expect(loadingWrapper).toBeInTheDocument() + }) + }) + + describe('error state', () => { + it('returns null when isError is true', () => { + mockUseGetBanners.mockReturnValue({ + data: null, + isLoading: false, + isError: true, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + }) + + describe('empty state', () => { + it('returns null when banners array is empty', () => { + mockUseGetBanners.mockReturnValue({ + data: [], + isLoading: false, + isError: false, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it('returns null when all banners are disabled', () => { + mockUseGetBanners.mockReturnValue({ + data: [ + createMockBanner('1', 'disabled'), + createMockBanner('2', 'disabled'), + ], + isLoading: false, + isError: false, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it('returns null when data is undefined', () => { + mockUseGetBanners.mockReturnValue({ + data: undefined, + isLoading: false, + isError: false, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + }) + + describe('successful render', () => { + it('renders carousel when enabled banners exist', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + expect(screen.getByTestId('carousel')).toBeInTheDocument() + }) + + it('renders only enabled banners', () => { + mockUseGetBanners.mockReturnValue({ + data: [ + createMockBanner('1', 'enabled', 'Enabled Banner 1'), + createMockBanner('2', 'disabled', 'Disabled Banner'), + createMockBanner('3', 'enabled', 'Enabled Banner 2'), + ], + isLoading: false, + isError: false, + }) + + render() + + const bannerItems = screen.getAllByTestId('banner-item') + expect(bannerItems).toHaveLength(2) + expect(screen.getByText('BannerItem: Enabled Banner 1')).toBeInTheDocument() + expect(screen.getByText('BannerItem: Enabled Banner 2')).toBeInTheDocument() + expect(screen.queryByText('BannerItem: Disabled Banner')).not.toBeInTheDocument() + }) + + it('passes correct autoplayDelay to BannerItem', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-autoplay-delay', '5000') + }) + + it('renders carousel with correct class', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + expect(screen.getByTestId('carousel')).toHaveClass('rounded-2xl') + }) + }) + + describe('hover behavior', () => { + it('sets isPaused to true on mouse enter', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + const carousel = screen.getByTestId('carousel') + fireEvent.mouseEnter(carousel) + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'true') + }) + + it('sets isPaused to false on mouse leave', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + const carousel = screen.getByTestId('carousel') + + // Enter and then leave + fireEvent.mouseEnter(carousel) + fireEvent.mouseLeave(carousel) + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'false') + }) + }) + + describe('resize behavior', () => { + it('pauses animation during resize', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + // Trigger resize event + act(() => { + window.dispatchEvent(new Event('resize')) + }) + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'true') + }) + + it('resumes animation after resize debounce delay', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + // Trigger resize event + act(() => { + window.dispatchEvent(new Event('resize')) + }) + + // Wait for debounce delay (50ms) + act(() => { + vi.advanceTimersByTime(50) + }) + + const bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'false') + }) + + it('resets debounce timer on multiple resize events', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + render() + + // Trigger first resize event + act(() => { + window.dispatchEvent(new Event('resize')) + }) + + // Wait partial time + act(() => { + vi.advanceTimersByTime(30) + }) + + // Trigger second resize event + act(() => { + window.dispatchEvent(new Event('resize')) + }) + + // Wait another 30ms (total 60ms from second resize but only 30ms after) + act(() => { + vi.advanceTimersByTime(30) + }) + + // Should still be paused (debounce resets) + let bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'true') + + // Wait remaining time + act(() => { + vi.advanceTimersByTime(20) + }) + + bannerItem = screen.getByTestId('banner-item') + expect(bannerItem).toHaveAttribute('data-is-paused', 'false') + }) + }) + + describe('cleanup', () => { + it('removes resize event listener on unmount', () => { + const removeEventListenerSpy = vi.spyOn(window, 'removeEventListener') + + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + const { unmount } = render() + unmount() + + expect(removeEventListenerSpy).toHaveBeenCalledWith('resize', expect.any(Function)) + removeEventListenerSpy.mockRestore() + }) + + it('clears resize timer on unmount', () => { + const clearTimeoutSpy = vi.spyOn(globalThis, 'clearTimeout') + + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + const { unmount } = render() + + // Trigger resize to create timer + act(() => { + window.dispatchEvent(new Event('resize')) + }) + + unmount() + + expect(clearTimeoutSpy).toHaveBeenCalled() + clearTimeoutSpy.mockRestore() + }) + }) + + describe('hook calls', () => { + it('calls useGetBanners with correct locale', () => { + mockUseGetBanners.mockReturnValue({ + data: [], + isLoading: false, + isError: false, + }) + + render() + + expect(mockUseGetBanners).toHaveBeenCalledWith('en-US') + }) + }) + + describe('multiple banners', () => { + it('renders all enabled banners in carousel items', () => { + mockUseGetBanners.mockReturnValue({ + data: [ + createMockBanner('1', 'enabled', 'Banner 1'), + createMockBanner('2', 'enabled', 'Banner 2'), + createMockBanner('3', 'enabled', 'Banner 3'), + ], + isLoading: false, + isError: false, + }) + + render() + + const carouselItems = screen.getAllByTestId('carousel-item') + expect(carouselItems).toHaveLength(3) + }) + + it('preserves banner order', () => { + mockUseGetBanners.mockReturnValue({ + data: [ + createMockBanner('1', 'enabled', 'First Banner'), + createMockBanner('2', 'enabled', 'Second Banner'), + createMockBanner('3', 'enabled', 'Third Banner'), + ], + isLoading: false, + isError: false, + }) + + render() + + const bannerItems = screen.getAllByTestId('banner-item') + expect(bannerItems[0]).toHaveAttribute('data-banner-id', '1') + expect(bannerItems[1]).toHaveAttribute('data-banner-id', '2') + expect(bannerItems[2]).toHaveAttribute('data-banner-id', '3') + }) + }) + + describe('React.memo behavior', () => { + it('renders as memoized component', () => { + mockUseGetBanners.mockReturnValue({ + data: [createMockBanner('1', 'enabled')], + isLoading: false, + isError: false, + }) + + const { rerender } = render() + + // Re-render with same props + rerender() + + // Component should still be present (memo doesn't break rendering) + expect(screen.getByTestId('carousel')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/explore/banner/indicator-button.spec.tsx b/web/app/components/explore/banner/indicator-button.spec.tsx new file mode 100644 index 0000000000..545f4e2f9a --- /dev/null +++ b/web/app/components/explore/banner/indicator-button.spec.tsx @@ -0,0 +1,448 @@ +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { IndicatorButton } from './indicator-button' + +describe('IndicatorButton', () => { + beforeEach(() => { + vi.useFakeTimers() + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + vi.useRealTimers() + }) + + describe('basic rendering', () => { + it('renders button with correct index number', () => { + const mockOnClick = vi.fn() + render( + , + ) + + expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByText('01')).toBeInTheDocument() + }) + + it('renders two-digit index numbers', () => { + const mockOnClick = vi.fn() + render( + , + ) + + expect(screen.getByText('10')).toBeInTheDocument() + }) + + it('pads single digit index numbers with leading zero', () => { + const mockOnClick = vi.fn() + render( + , + ) + + expect(screen.getByText('05')).toBeInTheDocument() + }) + }) + + describe('active state', () => { + it('applies active styles when index equals selectedIndex', () => { + const mockOnClick = vi.fn() + render( + , + ) + + const button = screen.getByRole('button') + expect(button).toHaveClass('bg-text-primary') + }) + + it('applies inactive styles when index does not equal selectedIndex', () => { + const mockOnClick = vi.fn() + render( + , + ) + + const button = screen.getByRole('button') + expect(button).toHaveClass('bg-components-panel-on-panel-item-bg') + }) + }) + + describe('click handling', () => { + it('calls onClick when button is clicked', () => { + const mockOnClick = vi.fn() + render( + , + ) + + fireEvent.click(screen.getByRole('button')) + expect(mockOnClick).toHaveBeenCalledTimes(1) + }) + + it('stops event propagation when clicked', () => { + const mockOnClick = vi.fn() + const mockParentClick = vi.fn() + + render( +
    + +
    , + ) + + fireEvent.click(screen.getByRole('button')) + expect(mockOnClick).toHaveBeenCalledTimes(1) + expect(mockParentClick).not.toHaveBeenCalled() + }) + }) + + describe('progress indicator', () => { + it('does not show progress indicator when not next slide', () => { + const mockOnClick = vi.fn() + const { container } = render( + , + ) + + // Check for conic-gradient style which indicates progress indicator + const progressIndicator = container.querySelector('[style*="conic-gradient"]') + expect(progressIndicator).not.toBeInTheDocument() + }) + + it('shows progress indicator when isNextSlide is true and not active', () => { + const mockOnClick = vi.fn() + const { container } = render( + , + ) + + const progressIndicator = container.querySelector('[style*="conic-gradient"]') + expect(progressIndicator).toBeInTheDocument() + }) + + it('does not show progress indicator when isNextSlide but also active', () => { + const mockOnClick = vi.fn() + const { container } = render( + , + ) + + const progressIndicator = container.querySelector('[style*="conic-gradient"]') + expect(progressIndicator).not.toBeInTheDocument() + }) + }) + + describe('animation behavior', () => { + it('starts progress from 0 when isNextSlide becomes true', () => { + const mockOnClick = vi.fn() + const { container, rerender } = render( + , + ) + + // Initially no progress indicator + expect(container.querySelector('[style*="conic-gradient"]')).not.toBeInTheDocument() + + // Rerender with isNextSlide=true + rerender( + , + ) + + // Now progress indicator should be visible + expect(container.querySelector('[style*="conic-gradient"]')).toBeInTheDocument() + }) + + it('resets progress when resetKey changes', () => { + const mockOnClick = vi.fn() + const { rerender, container } = render( + , + ) + + // Progress indicator should be present + const progressIndicator = container.querySelector('[style*="conic-gradient"]') + expect(progressIndicator).toBeInTheDocument() + + // Rerender with new resetKey - this should reset the progress animation + rerender( + , + ) + + const newProgressIndicator = container.querySelector('[style*="conic-gradient"]') + // The progress indicator should still be present after reset + expect(newProgressIndicator).toBeInTheDocument() + }) + + it('stops animation when isPaused is true', () => { + const mockOnClick = vi.fn() + const mockRequestAnimationFrame = vi.spyOn(window, 'requestAnimationFrame') + + render( + , + ) + + // The component should still render but animation should be paused + // requestAnimationFrame might still be called for polling but progress won't update + expect(screen.getByRole('button')).toBeInTheDocument() + mockRequestAnimationFrame.mockRestore() + }) + + it('cancels animation frame on unmount', () => { + const mockOnClick = vi.fn() + const mockCancelAnimationFrame = vi.spyOn(window, 'cancelAnimationFrame') + + const { unmount } = render( + , + ) + + // Trigger animation frame + act(() => { + vi.advanceTimersToNextTimer() + }) + + unmount() + + expect(mockCancelAnimationFrame).toHaveBeenCalled() + mockCancelAnimationFrame.mockRestore() + }) + + it('cancels animation frame when isNextSlide becomes false', () => { + const mockOnClick = vi.fn() + const mockCancelAnimationFrame = vi.spyOn(window, 'cancelAnimationFrame') + + const { rerender } = render( + , + ) + + // Trigger animation frame + act(() => { + vi.advanceTimersToNextTimer() + }) + + // Change isNextSlide to false - this should cancel the animation frame + rerender( + , + ) + + expect(mockCancelAnimationFrame).toHaveBeenCalled() + mockCancelAnimationFrame.mockRestore() + }) + + it('continues polling when document is hidden', () => { + const mockOnClick = vi.fn() + const mockRequestAnimationFrame = vi.spyOn(window, 'requestAnimationFrame') + + // Mock document.hidden to be true + Object.defineProperty(document, 'hidden', { + writable: true, + configurable: true, + value: true, + }) + + render( + , + ) + + // Component should still render + expect(screen.getByRole('button')).toBeInTheDocument() + + // Reset document.hidden + Object.defineProperty(document, 'hidden', { + writable: true, + configurable: true, + value: false, + }) + + mockRequestAnimationFrame.mockRestore() + }) + }) + + describe('isPaused prop default', () => { + it('defaults isPaused to false when not provided', () => { + const mockOnClick = vi.fn() + const { container } = render( + , + ) + + // Progress indicator should be visible (animation running) + expect(container.querySelector('[style*="conic-gradient"]')).toBeInTheDocument() + }) + }) + + describe('button styling', () => { + it('has correct base classes', () => { + const mockOnClick = vi.fn() + render( + , + ) + + const button = screen.getByRole('button') + expect(button).toHaveClass('relative') + expect(button).toHaveClass('flex') + expect(button).toHaveClass('items-center') + expect(button).toHaveClass('justify-center') + expect(button).toHaveClass('rounded-[7px]') + expect(button).toHaveClass('border') + expect(button).toHaveClass('transition-colors') + }) + }) +}) diff --git a/web/app/components/explore/try-app/app-info/index.spec.tsx b/web/app/components/explore/try-app/app-info/index.spec.tsx new file mode 100644 index 0000000000..cfae862a72 --- /dev/null +++ b/web/app/components/explore/try-app/app-info/index.spec.tsx @@ -0,0 +1,395 @@ +import type { TryAppInfo } from '@/service/try-app' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import AppInfo from './index' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'types.advanced': 'Advanced', + 'types.chatbot': 'Chatbot', + 'types.agent': 'Agent', + 'types.workflow': 'Workflow', + 'types.completion': 'Completion', + 'tryApp.createFromSampleApp': 'Create from Sample', + 'tryApp.category': 'Category', + 'tryApp.requirements': 'Requirements', + } + return translations[key] || key + }, + }), +})) + +const mockUseGetRequirements = vi.fn() + +vi.mock('./use-get-requirements', () => ({ + default: (...args: unknown[]) => mockUseGetRequirements(...args), +})) + +const createMockAppDetail = (mode: string, overrides: Partial = {}): TryAppInfo => ({ + id: 'test-app-id', + name: 'Test App Name', + description: 'Test App Description', + mode, + site: { + title: 'Test Site Title', + icon: '๐Ÿš€', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + }, + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { + datasets: { + datasets: [], + }, + }, + agent_mode: { + tools: [], + }, + user_input_form: [], + }, + ...overrides, +} as unknown as TryAppInfo) + +describe('AppInfo', () => { + beforeEach(() => { + mockUseGetRequirements.mockReturnValue({ + requirements: [], + }) + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + }) + + describe('app name and icon', () => { + it('renders app name', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('Test App Name')).toBeInTheDocument() + }) + + it('renders app name with title attribute', () => { + const appDetail = createMockAppDetail('chat', { + name: 'Very Long App Name That Should Be Truncated', + } as Partial) + const mockOnCreate = vi.fn() + + render( + , + ) + + const nameElement = screen.getByText('Very Long App Name That Should Be Truncated') + expect(nameElement).toHaveAttribute('title', 'Very Long App Name That Should Be Truncated') + }) + }) + + describe('app type', () => { + it('displays ADVANCED for advanced-chat mode', () => { + const appDetail = createMockAppDetail('advanced-chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('ADVANCED')).toBeInTheDocument() + }) + + it('displays CHATBOT for chat mode', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('CHATBOT')).toBeInTheDocument() + }) + + it('displays AGENT for agent-chat mode', () => { + const appDetail = createMockAppDetail('agent-chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('AGENT')).toBeInTheDocument() + }) + + it('displays WORKFLOW for workflow mode', () => { + const appDetail = createMockAppDetail('workflow') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('WORKFLOW')).toBeInTheDocument() + }) + + it('displays COMPLETION for completion mode', () => { + const appDetail = createMockAppDetail('completion') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('COMPLETION')).toBeInTheDocument() + }) + }) + + describe('description', () => { + it('renders description when provided', () => { + const appDetail = createMockAppDetail('chat', { + description: 'This is a test description', + } as Partial) + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('This is a test description')).toBeInTheDocument() + }) + + it('does not render description when empty', () => { + const appDetail = createMockAppDetail('chat', { + description: '', + } as Partial) + const mockOnCreate = vi.fn() + + const { container } = render( + , + ) + + // Check that there's no element with the description class that has empty content + const descriptionElements = container.querySelectorAll('.system-sm-regular.mt-\\[14px\\]') + expect(descriptionElements.length).toBe(0) + }) + }) + + describe('create button', () => { + it('renders create button with correct text', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('Create from Sample')).toBeInTheDocument() + }) + + it('calls onCreate when button is clicked', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByText('Create from Sample')) + expect(mockOnCreate).toHaveBeenCalledTimes(1) + }) + }) + + describe('category', () => { + it('renders category when provided', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('Category')).toBeInTheDocument() + expect(screen.getByText('AI Assistant')).toBeInTheDocument() + }) + + it('does not render category section when not provided', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.queryByText('Category')).not.toBeInTheDocument() + }) + }) + + describe('requirements', () => { + it('renders requirements when available', () => { + mockUseGetRequirements.mockReturnValue({ + requirements: [ + { name: 'OpenAI GPT-4', iconUrl: 'https://example.com/icon1.png' }, + { name: 'Google Search', iconUrl: 'https://example.com/icon2.png' }, + ], + }) + + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.getByText('Requirements')).toBeInTheDocument() + expect(screen.getByText('OpenAI GPT-4')).toBeInTheDocument() + expect(screen.getByText('Google Search')).toBeInTheDocument() + }) + + it('does not render requirements section when empty', () => { + mockUseGetRequirements.mockReturnValue({ + requirements: [], + }) + + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(screen.queryByText('Requirements')).not.toBeInTheDocument() + }) + + it('renders requirement icons with correct background image', () => { + mockUseGetRequirements.mockReturnValue({ + requirements: [ + { name: 'Test Tool', iconUrl: 'https://example.com/test-icon.png' }, + ], + }) + + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + const { container } = render( + , + ) + + const iconElement = container.querySelector('[style*="background-image"]') + expect(iconElement).toBeInTheDocument() + expect(iconElement).toHaveStyle({ backgroundImage: 'url(https://example.com/test-icon.png)' }) + }) + }) + + describe('className prop', () => { + it('applies custom className', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + const { container } = render( + , + ) + + expect(container.firstChild).toHaveClass('custom-class') + }) + }) + + describe('hook calls', () => { + it('calls useGetRequirements with correct parameters', () => { + const appDetail = createMockAppDetail('chat') + const mockOnCreate = vi.fn() + + render( + , + ) + + expect(mockUseGetRequirements).toHaveBeenCalledWith({ + appDetail, + appId: 'my-app-id', + }) + }) + }) +}) diff --git a/web/app/components/explore/try-app/app-info/use-get-requirements.spec.ts b/web/app/components/explore/try-app/app-info/use-get-requirements.spec.ts new file mode 100644 index 0000000000..c8af6121d1 --- /dev/null +++ b/web/app/components/explore/try-app/app-info/use-get-requirements.spec.ts @@ -0,0 +1,425 @@ +import type { TryAppInfo } from '@/service/try-app' +import { renderHook } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import useGetRequirements from './use-get-requirements' + +const mockUseGetTryAppFlowPreview = vi.fn() + +vi.mock('@/service/use-try-app', () => ({ + useGetTryAppFlowPreview: (...args: unknown[]) => mockUseGetTryAppFlowPreview(...args), +})) + +vi.mock('@/config', () => ({ + MARKETPLACE_API_PREFIX: 'https://marketplace.api', +})) + +const createMockAppDetail = (mode: string, overrides: Partial = {}): TryAppInfo => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test Description', + mode, + site: { + title: 'Test Site Title', + icon: 'icon', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + }, + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { + datasets: { + datasets: [], + }, + }, + agent_mode: { + tools: [], + }, + user_input_form: [], + }, + ...overrides, +} as unknown as TryAppInfo) + +describe('useGetRequirements', () => { + afterEach(() => { + vi.clearAllMocks() + }) + + describe('basic app modes (chat, completion, agent-chat)', () => { + it('returns model provider for chat mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('chat') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(1) + expect(result.current.requirements[0].name).toBe('openai') + expect(result.current.requirements[0].iconUrl).toBe('https://marketplace.api/plugins/langgenius/openai/icon') + }) + + it('returns model provider for completion mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('completion', { + model_config: { + model: { + provider: 'anthropic/claude/claude', + name: 'claude-3', + mode: 'completion', + }, + dataset_configs: { datasets: { datasets: [] } }, + agent_mode: { tools: [] }, + user_input_form: [], + }, + } as unknown as Partial) + + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(1) + expect(result.current.requirements[0].name).toBe('claude') + }) + + it('returns model provider and tools for agent-chat mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('agent-chat', { + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { datasets: { datasets: [] } }, + agent_mode: { + tools: [ + { + enabled: true, + provider_id: 'langgenius/google_search/google_search', + tool_label: 'Google Search', + }, + { + enabled: true, + provider_id: 'langgenius/web_scraper/web_scraper', + tool_label: 'Web Scraper', + }, + { + enabled: false, + provider_id: 'langgenius/disabled_tool/disabled_tool', + tool_label: 'Disabled Tool', + }, + ], + }, + user_input_form: [], + }, + } as unknown as Partial) + + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(3) + expect(result.current.requirements.map(r => r.name)).toContain('openai') + expect(result.current.requirements.map(r => r.name)).toContain('Google Search') + expect(result.current.requirements.map(r => r.name)).toContain('Web Scraper') + expect(result.current.requirements.map(r => r.name)).not.toContain('Disabled Tool') + }) + + it('filters out disabled tools in agent mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('agent-chat', { + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { datasets: { datasets: [] } }, + agent_mode: { + tools: [ + { + enabled: false, + provider_id: 'langgenius/tool1/tool1', + tool_label: 'Tool 1', + }, + { + enabled: false, + provider_id: 'langgenius/tool2/tool2', + tool_label: 'Tool 2', + }, + ], + }, + user_input_form: [], + }, + } as unknown as Partial) + + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + // Only model provider should be included, no disabled tools + expect(result.current.requirements).toHaveLength(1) + expect(result.current.requirements[0].name).toBe('openai') + }) + }) + + describe('advanced app modes (workflow, advanced-chat)', () => { + it('returns requirements from flow data for workflow mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [ + { + data: { + type: 'llm', + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + }, + }, + }, + { + data: { + type: 'tool', + provider_id: 'langgenius/google/google', + tool_label: 'Google Tool', + }, + }, + ], + }, + }, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(2) + expect(result.current.requirements.map(r => r.name)).toContain('gpt-4') + expect(result.current.requirements.map(r => r.name)).toContain('Google Tool') + }) + + it('returns requirements from flow data for advanced-chat mode', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [ + { + data: { + type: 'llm', + model: { + provider: 'anthropic/claude/claude', + name: 'claude-3-opus', + }, + }, + }, + ], + }, + }, + }) + + const appDetail = createMockAppDetail('advanced-chat') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(1) + expect(result.current.requirements[0].name).toBe('claude-3-opus') + }) + + it('returns empty requirements when flow data has no nodes', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [], + }, + }, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(0) + }) + + it('returns empty requirements when flow data is null', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: null, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(0) + }) + + it('extracts multiple LLM nodes from flow data', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [ + { + data: { + type: 'llm', + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + }, + }, + }, + { + data: { + type: 'llm', + model: { + provider: 'anthropic/claude/claude', + name: 'claude-3', + }, + }, + }, + ], + }, + }, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(2) + expect(result.current.requirements.map(r => r.name)).toContain('gpt-4') + expect(result.current.requirements.map(r => r.name)).toContain('claude-3') + }) + + it('extracts multiple tool nodes from flow data', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [ + { + data: { + type: 'tool', + provider_id: 'langgenius/tool1/tool1', + tool_label: 'Tool 1', + }, + }, + { + data: { + type: 'tool', + provider_id: 'langgenius/tool2/tool2', + tool_label: 'Tool 2', + }, + }, + ], + }, + }, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(2) + expect(result.current.requirements.map(r => r.name)).toContain('Tool 1') + expect(result.current.requirements.map(r => r.name)).toContain('Tool 2') + }) + }) + + describe('deduplication', () => { + it('removes duplicate requirements by name', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [ + { + data: { + type: 'llm', + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + }, + }, + }, + { + data: { + type: 'llm', + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + }, + }, + }, + ], + }, + }, + }) + + const appDetail = createMockAppDetail('workflow') + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements).toHaveLength(1) + expect(result.current.requirements[0].name).toBe('gpt-4') + }) + }) + + describe('icon URL generation', () => { + it('generates correct icon URL for model providers', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('chat', { + model_config: { + model: { + provider: 'org/plugin/model', + name: 'model-name', + mode: 'chat', + }, + dataset_configs: { datasets: { datasets: [] } }, + agent_mode: { tools: [] }, + user_input_form: [], + }, + } as unknown as Partial) + + const { result } = renderHook(() => + useGetRequirements({ appDetail, appId: 'test-app-id' }), + ) + + expect(result.current.requirements[0].iconUrl).toBe('https://marketplace.api/plugins/org/plugin/icon') + }) + }) + + describe('hook calls', () => { + it('calls useGetTryAppFlowPreview with correct parameters for basic apps', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('chat') + renderHook(() => useGetRequirements({ appDetail, appId: 'test-app-id' })) + + expect(mockUseGetTryAppFlowPreview).toHaveBeenCalledWith('test-app-id', true) + }) + + it('calls useGetTryAppFlowPreview with correct parameters for advanced apps', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ data: null }) + + const appDetail = createMockAppDetail('workflow') + renderHook(() => useGetRequirements({ appDetail, appId: 'test-app-id' })) + + expect(mockUseGetTryAppFlowPreview).toHaveBeenCalledWith('test-app-id', false) + }) + }) +}) diff --git a/web/app/components/explore/try-app/app/chat.spec.tsx b/web/app/components/explore/try-app/app/chat.spec.tsx new file mode 100644 index 0000000000..ebd430c4e8 --- /dev/null +++ b/web/app/components/explore/try-app/app/chat.spec.tsx @@ -0,0 +1,357 @@ +import type { TryAppInfo } from '@/service/try-app' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import TryApp from './chat' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'chat.resetChat': 'Reset Chat', + 'tryApp.tryInfo': 'This is try mode info', + } + return translations[key] || key + }, + }), +})) + +const mockRemoveConversationIdInfo = vi.fn() +const mockHandleNewConversation = vi.fn() +const mockUseEmbeddedChatbot = vi.fn() + +vi.mock('@/app/components/base/chat/embedded-chatbot/hooks', () => ({ + useEmbeddedChatbot: (...args: unknown[]) => mockUseEmbeddedChatbot(...args), +})) + +vi.mock('@/hooks/use-breakpoints', () => ({ + default: () => 'pc', + MediaType: { + mobile: 'mobile', + pc: 'pc', + }, +})) + +vi.mock('../../../base/chat/embedded-chatbot/theme/theme-context', () => ({ + useThemeContext: () => ({ + primaryColor: '#1890ff', + }), +})) + +vi.mock('@/app/components/base/chat/embedded-chatbot/chat-wrapper', () => ({ + default: () =>
    ChatWrapper
    , +})) + +vi.mock('@/app/components/base/chat/embedded-chatbot/inputs-form/view-form-dropdown', () => ({ + default: () =>
    ViewFormDropdown
    , +})) + +const createMockAppDetail = (overrides: Partial = {}): TryAppInfo => ({ + id: 'test-app-id', + name: 'Test Chat App', + description: 'Test Description', + mode: 'chat', + site: { + title: 'Test Site Title', + icon: '๐Ÿ’ฌ', + icon_type: 'emoji', + icon_background: '#4F46E5', + icon_url: '', + }, + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { + datasets: { + datasets: [], + }, + }, + agent_mode: { + tools: [], + }, + user_input_form: [], + }, + ...overrides, +} as unknown as TryAppInfo) + +describe('TryApp (chat.tsx)', () => { + beforeEach(() => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: null, + inputsForms: [], + }) + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + }) + + describe('basic rendering', () => { + it('renders app name', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.getByText('Test Chat App')).toBeInTheDocument() + }) + + it('renders app name with title attribute', () => { + const appDetail = createMockAppDetail({ name: 'Long App Name' } as Partial) + + render( + , + ) + + const nameElement = screen.getByText('Long App Name') + expect(nameElement).toHaveAttribute('title', 'Long App Name') + }) + + it('renders ChatWrapper', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.getByTestId('chat-wrapper')).toBeInTheDocument() + }) + + it('renders alert with try info', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.getByText('This is try mode info')).toBeInTheDocument() + }) + + it('applies className prop', () => { + const appDetail = createMockAppDetail() + + const { container } = render( + , + ) + + // The component wraps with EmbeddedChatbotContext.Provider, first child is the div with className + const innerDiv = container.querySelector('.custom-class') + expect(innerDiv).toBeInTheDocument() + }) + }) + + describe('reset button', () => { + it('does not render reset button when no conversation', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: null, + inputsForms: [], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + // Reset button should not be present + expect(screen.queryByRole('button')).not.toBeInTheDocument() + }) + + it('renders reset button when conversation exists', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: 'conv-123', + inputsForms: [], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + // Should have a button (the reset button) + expect(screen.getByRole('button')).toBeInTheDocument() + }) + + it('calls handleNewConversation when reset button is clicked', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: 'conv-123', + inputsForms: [], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + fireEvent.click(screen.getByRole('button')) + + expect(mockRemoveConversationIdInfo).toHaveBeenCalledWith('test-app-id') + expect(mockHandleNewConversation).toHaveBeenCalled() + }) + }) + + describe('view form dropdown', () => { + it('does not render view form dropdown when no conversation', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: null, + inputsForms: [{ id: 'form1' }], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.queryByTestId('view-form-dropdown')).not.toBeInTheDocument() + }) + + it('does not render view form dropdown when no input forms', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: 'conv-123', + inputsForms: [], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.queryByTestId('view-form-dropdown')).not.toBeInTheDocument() + }) + + it('renders view form dropdown when conversation and input forms exist', () => { + mockUseEmbeddedChatbot.mockReturnValue({ + removeConversationIdInfo: mockRemoveConversationIdInfo, + handleNewConversation: mockHandleNewConversation, + currentConversationId: 'conv-123', + inputsForms: [{ id: 'form1' }], + }) + + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(screen.getByTestId('view-form-dropdown')).toBeInTheDocument() + }) + }) + + describe('alert hiding', () => { + it('hides alert when onHide is called', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + // Find and click the hide button on the alert + const alertElement = screen.getByText('This is try mode info').closest('[class*="alert"]')?.parentElement + const hideButton = alertElement?.querySelector('button, [role="button"], svg') + + if (hideButton) { + fireEvent.click(hideButton) + // After hiding, the alert should not be visible + expect(screen.queryByText('This is try mode info')).not.toBeInTheDocument() + } + }) + }) + + describe('hook calls', () => { + it('calls useEmbeddedChatbot with correct parameters', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(mockUseEmbeddedChatbot).toHaveBeenCalledWith('tryApp', 'my-app-id') + }) + + it('calls removeConversationIdInfo on mount', () => { + const appDetail = createMockAppDetail() + + render( + , + ) + + expect(mockRemoveConversationIdInfo).toHaveBeenCalledWith('my-app-id') + }) + }) +}) diff --git a/web/app/components/explore/try-app/app/index.spec.tsx b/web/app/components/explore/try-app/app/index.spec.tsx new file mode 100644 index 0000000000..927365a648 --- /dev/null +++ b/web/app/components/explore/try-app/app/index.spec.tsx @@ -0,0 +1,188 @@ +import type { TryAppInfo } from '@/service/try-app' +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import TryApp from './index' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/hooks/use-document-title', () => ({ + default: vi.fn(), +})) + +vi.mock('./chat', () => ({ + default: ({ appId, appDetail, className }: { appId: string, appDetail: TryAppInfo, className: string }) => ( +
    + Chat Component +
    + ), +})) + +vi.mock('./text-generation', () => ({ + default: ({ + appId, + className, + isWorkflow, + appData, + }: { appId: string, className: string, isWorkflow: boolean, appData: { mode: string } }) => ( +
    + TextGeneration Component +
    + ), +})) + +const createMockAppDetail = (mode: string): TryAppInfo => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test Description', + mode, + site: { + title: 'Test Site Title', + icon: 'icon', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + }, + model_config: { + model: { + provider: 'test/provider', + name: 'test-model', + mode: 'chat', + }, + dataset_configs: { + datasets: { + datasets: [], + }, + }, + agent_mode: { + tools: [], + }, + user_input_form: [], + }, +} as unknown as TryAppInfo) + +describe('TryApp (app/index.tsx)', () => { + afterEach(() => { + cleanup() + }) + + describe('chat mode rendering', () => { + it('renders Chat component for chat mode', () => { + const appDetail = createMockAppDetail('chat') + render() + + expect(screen.getByTestId('chat-component')).toBeInTheDocument() + expect(screen.queryByTestId('text-generation-component')).not.toBeInTheDocument() + }) + + it('renders Chat component for advanced-chat mode', () => { + const appDetail = createMockAppDetail('advanced-chat') + render() + + expect(screen.getByTestId('chat-component')).toBeInTheDocument() + expect(screen.queryByTestId('text-generation-component')).not.toBeInTheDocument() + }) + + it('renders Chat component for agent-chat mode', () => { + const appDetail = createMockAppDetail('agent-chat') + render() + + expect(screen.getByTestId('chat-component')).toBeInTheDocument() + expect(screen.queryByTestId('text-generation-component')).not.toBeInTheDocument() + }) + + it('passes correct props to Chat component', () => { + const appDetail = createMockAppDetail('chat') + render() + + const chatComponent = screen.getByTestId('chat-component') + expect(chatComponent).toHaveAttribute('data-app-id', 'test-app-id') + expect(chatComponent).toHaveAttribute('data-mode', 'chat') + expect(chatComponent).toHaveClass('h-full', 'grow') + }) + }) + + describe('completion mode rendering', () => { + it('renders TextGeneration component for completion mode', () => { + const appDetail = createMockAppDetail('completion') + render() + + expect(screen.getByTestId('text-generation-component')).toBeInTheDocument() + expect(screen.queryByTestId('chat-component')).not.toBeInTheDocument() + }) + + it('renders TextGeneration component for workflow mode', () => { + const appDetail = createMockAppDetail('workflow') + render() + + expect(screen.getByTestId('text-generation-component')).toBeInTheDocument() + expect(screen.queryByTestId('chat-component')).not.toBeInTheDocument() + }) + + it('passes isWorkflow=true for workflow mode', () => { + const appDetail = createMockAppDetail('workflow') + render() + + const textGenComponent = screen.getByTestId('text-generation-component') + expect(textGenComponent).toHaveAttribute('data-is-workflow', 'true') + }) + + it('passes isWorkflow=false for completion mode', () => { + const appDetail = createMockAppDetail('completion') + render() + + const textGenComponent = screen.getByTestId('text-generation-component') + expect(textGenComponent).toHaveAttribute('data-is-workflow', 'false') + }) + + it('passes correct props to TextGeneration component', () => { + const appDetail = createMockAppDetail('completion') + render() + + const textGenComponent = screen.getByTestId('text-generation-component') + expect(textGenComponent).toHaveAttribute('data-app-id', 'test-app-id') + expect(textGenComponent).toHaveClass('h-full', 'grow') + }) + }) + + describe('document title', () => { + it('calls useDocumentTitle with site title', async () => { + const useDocumentTitle = (await import('@/hooks/use-document-title')).default + const appDetail = createMockAppDetail('chat') + appDetail.site.title = 'My App Title' + + render() + + expect(useDocumentTitle).toHaveBeenCalledWith('My App Title') + }) + + it('calls useDocumentTitle with empty string when site.title is undefined', async () => { + const useDocumentTitle = (await import('@/hooks/use-document-title')).default + const appDetail = createMockAppDetail('chat') + appDetail.site = undefined as unknown as TryAppInfo['site'] + + render() + + expect(useDocumentTitle).toHaveBeenCalledWith('') + }) + }) + + describe('wrapper styling', () => { + it('renders with correct wrapper classes', () => { + const appDetail = createMockAppDetail('chat') + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'h-full', 'w-full') + }) + }) +}) diff --git a/web/app/components/explore/try-app/app/text-generation.spec.tsx b/web/app/components/explore/try-app/app/text-generation.spec.tsx new file mode 100644 index 0000000000..cbeafc5132 --- /dev/null +++ b/web/app/components/explore/try-app/app/text-generation.spec.tsx @@ -0,0 +1,468 @@ +import type { AppData } from '@/models/share' +import { cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import TextGeneration from './text-generation' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'tryApp.tryInfo': 'This is a try app notice', + } + return translations[key] || key + }, + }), +})) + +const mockUpdateAppInfo = vi.fn() +const mockUpdateAppParams = vi.fn() +const mockAppParams = { + user_input_form: [], + more_like_this: { enabled: false }, + file_upload: null, + text_to_speech: { enabled: false }, + system_parameters: {}, +} +let mockStoreAppParams: typeof mockAppParams | null = mockAppParams + +vi.mock('@/context/web-app-context', () => ({ + useWebAppStore: (selector: (state: unknown) => unknown) => { + const state = { + updateAppInfo: mockUpdateAppInfo, + updateAppParams: mockUpdateAppParams, + appParams: mockStoreAppParams, + } + return selector(state) + }, +})) + +const mockUseGetTryAppParams = vi.fn() + +vi.mock('@/service/use-try-app', () => ({ + useGetTryAppParams: (...args: unknown[]) => mockUseGetTryAppParams(...args), +})) + +let mockMediaType = 'pc' + +vi.mock('@/hooks/use-breakpoints', () => ({ + default: () => mockMediaType, + MediaType: { + mobile: 'mobile', + pc: 'pc', + }, +})) + +vi.mock('@/app/components/share/text-generation/run-once', () => ({ + default: ({ + siteInfo, + onSend, + onInputsChange, + }: { siteInfo: { title: string }, onSend: () => void, onInputsChange: (inputs: Record) => void }) => ( +
    + {siteInfo?.title} + + +
    + ), +})) + +vi.mock('@/app/components/share/text-generation/result', () => ({ + default: ({ + isWorkflow, + appId, + onCompleted, + onRunStart, + }: { isWorkflow: boolean, appId: string, onCompleted: () => void, onRunStart: () => void }) => ( +
    + + +
    + ), +})) + +const createMockAppData = (overrides: Partial = {}): AppData => ({ + app_id: 'test-app-id', + site: { + title: 'Test App Title', + description: 'Test App Description', + icon: '๐Ÿš€', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + default_language: 'en', + prompt_public: true, + copyright: '', + privacy_policy: '', + custom_disclaimer: '', + }, + custom_config: { + remove_webapp_brand: false, + }, + ...overrides, +} as AppData) + +describe('TextGeneration', () => { + beforeEach(() => { + mockStoreAppParams = mockAppParams + mockMediaType = 'pc' + mockUseGetTryAppParams.mockReturnValue({ + data: mockAppParams, + }) + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + }) + + describe('loading state', () => { + it('renders loading when appData is null', () => { + render( + , + ) + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('renders loading when appParams is not available', () => { + mockStoreAppParams = null + mockUseGetTryAppParams.mockReturnValue({ + data: null, + }) + + render( + , + ) + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + }) + + describe('content rendering', () => { + it('renders app title', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + // Multiple elements may have the title (header and RunOnce mock) + const titles = screen.getAllByText('Test App Title') + expect(titles.length).toBeGreaterThan(0) + }) + }) + + it('renders app description when available', async () => { + const appData = createMockAppData({ + site: { + title: 'Test App', + description: 'This is a description', + icon: '๐Ÿš€', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + default_language: 'en', + prompt_public: true, + copyright: '', + privacy_policy: '', + custom_disclaimer: '', + }, + } as unknown as Partial) + + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('This is a description')).toBeInTheDocument() + }) + }) + + it('renders RunOnce component', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('run-once')).toBeInTheDocument() + }) + }) + + it('renders Result component', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('result-component')).toBeInTheDocument() + }) + }) + }) + + describe('workflow mode', () => { + it('passes isWorkflow=true to Result when isWorkflow prop is true', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + const resultComponent = screen.getByTestId('result-component') + expect(resultComponent).toHaveAttribute('data-is-workflow', 'true') + }) + }) + + it('passes isWorkflow=false to Result when isWorkflow prop is false', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + const resultComponent = screen.getByTestId('result-component') + expect(resultComponent).toHaveAttribute('data-is-workflow', 'false') + }) + }) + }) + + describe('send functionality', () => { + it('triggers send when RunOnce sends', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('send-button')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('send-button')) + + // The send should work without errors + expect(screen.getByTestId('result-component')).toBeInTheDocument() + }) + }) + + describe('completion handling', () => { + it('shows alert after completion', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('complete-button')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('complete-button')) + + await waitFor(() => { + expect(screen.getByText('This is a try app notice')).toBeInTheDocument() + }) + }) + }) + + describe('className prop', () => { + it('applies custom className', async () => { + const appData = createMockAppData() + + const { container } = render( + , + ) + + await waitFor(() => { + const element = container.querySelector('.custom-class') + expect(element).toBeInTheDocument() + }) + }) + }) + + describe('hook effects', () => { + it('calls updateAppInfo when appData changes', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(mockUpdateAppInfo).toHaveBeenCalledWith(appData) + }) + }) + + it('calls updateAppParams when tryAppParams changes', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(mockUpdateAppParams).toHaveBeenCalledWith(mockAppParams) + }) + }) + + it('calls useGetTryAppParams with correct appId', () => { + const appData = createMockAppData() + + render( + , + ) + + expect(mockUseGetTryAppParams).toHaveBeenCalledWith('my-app-id') + }) + }) + + describe('result panel visibility', () => { + it('shows result panel after run starts', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('run-start-button')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('run-start-button')) + + // Result panel should remain visible + expect(screen.getByTestId('result-component')).toBeInTheDocument() + }) + }) + + describe('input handling', () => { + it('handles input changes from RunOnce', async () => { + const appData = createMockAppData() + + render( + , + ) + + await waitFor(() => { + expect(screen.getByTestId('inputs-change-button')).toBeInTheDocument() + }) + + // Trigger input change which should call setInputs callback + fireEvent.click(screen.getByTestId('inputs-change-button')) + + // The component should handle the input change without errors + expect(screen.getByTestId('run-once')).toBeInTheDocument() + }) + }) + + describe('mobile behavior', () => { + it('renders mobile toggle panel on mobile', async () => { + mockMediaType = 'mobile' + const appData = createMockAppData() + + const { container } = render( + , + ) + + await waitFor(() => { + // Mobile toggle panel should be rendered + const togglePanel = container.querySelector('.cursor-grab') + expect(togglePanel).toBeInTheDocument() + }) + }) + + it('toggles result panel visibility on mobile', async () => { + mockMediaType = 'mobile' + const appData = createMockAppData() + + const { container } = render( + , + ) + + await waitFor(() => { + const togglePanel = container.querySelector('.cursor-grab') + expect(togglePanel).toBeInTheDocument() + }) + + // Click to show result panel + const toggleParent = container.querySelector('.cursor-grab')?.parentElement + if (toggleParent) { + fireEvent.click(toggleParent) + } + + // Click again to hide result panel + await waitFor(() => { + const newToggleParent = container.querySelector('.cursor-grab')?.parentElement + if (newToggleParent) { + fireEvent.click(newToggleParent) + } + }) + + // Component should handle both show and hide without errors + expect(screen.getByTestId('result-component')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/explore/try-app/index.spec.tsx b/web/app/components/explore/try-app/index.spec.tsx new file mode 100644 index 0000000000..3ae132b7ed --- /dev/null +++ b/web/app/components/explore/try-app/index.spec.tsx @@ -0,0 +1,411 @@ +import type { TryAppInfo } from '@/service/try-app' +import { cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import TryApp from './index' +import { TypeEnum } from './tab' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'tryApp.tabHeader.try': 'Try', + 'tryApp.tabHeader.detail': 'Detail', + } + return translations[key] || key + }, + }), +})) + +const mockUseGetTryAppInfo = vi.fn() + +vi.mock('@/service/use-try-app', () => ({ + useGetTryAppInfo: (...args: unknown[]) => mockUseGetTryAppInfo(...args), +})) + +vi.mock('./app', () => ({ + default: ({ appId, appDetail }: { appId: string, appDetail: TryAppInfo }) => ( +
    + App Component +
    + ), +})) + +vi.mock('./preview', () => ({ + default: ({ appId, appDetail }: { appId: string, appDetail: TryAppInfo }) => ( +
    + Preview Component +
    + ), +})) + +vi.mock('./app-info', () => ({ + default: ({ + appId, + appDetail, + category, + className, + onCreate, + }: { appId: string, appDetail: TryAppInfo, category?: string, className?: string, onCreate: () => void }) => ( +
    + + App Info: + {' '} + {appDetail?.name} +
    + ), +})) + +const createMockAppDetail = (mode: string = 'chat'): TryAppInfo => ({ + id: 'test-app-id', + name: 'Test App Name', + description: 'Test Description', + mode, + site: { + title: 'Test Site Title', + icon: '๐Ÿš€', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + }, + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + dataset_configs: { + datasets: { + datasets: [], + }, + }, + agent_mode: { + tools: [], + }, + user_input_form: [], + }, +} as unknown as TryAppInfo) + +describe('TryApp (main index.tsx)', () => { + beforeEach(() => { + mockUseGetTryAppInfo.mockReturnValue({ + data: createMockAppDetail(), + isLoading: false, + }) + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + }) + + describe('loading state', () => { + it('renders loading when isLoading is true', () => { + mockUseGetTryAppInfo.mockReturnValue({ + data: null, + isLoading: true, + }) + + render( + , + ) + + expect(document.body.querySelector('[role="status"]')).toBeInTheDocument() + }) + }) + + describe('content rendering', () => { + it('renders Tab component', async () => { + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('Try')).toBeInTheDocument() + expect(screen.getByText('Detail')).toBeInTheDocument() + }) + }) + + it('renders App component by default (TRY mode)', async () => { + render( + , + ) + + await waitFor(() => { + expect(document.body.querySelector('[data-testid="app-component"]')).toBeInTheDocument() + expect(document.body.querySelector('[data-testid="preview-component"]')).not.toBeInTheDocument() + }) + }) + + it('renders AppInfo component', async () => { + render( + , + ) + + await waitFor(() => { + expect(document.body.querySelector('[data-testid="app-info-component"]')).toBeInTheDocument() + }) + }) + + it('renders close button', async () => { + render( + , + ) + + await waitFor(() => { + // Find the close button (the one with RiCloseLine icon) + const buttons = document.body.querySelectorAll('button') + expect(buttons.length).toBeGreaterThan(0) + }) + }) + }) + + describe('tab switching', () => { + it('switches to Preview when Detail tab is clicked', async () => { + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('Detail')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByText('Detail')) + + await waitFor(() => { + expect(document.body.querySelector('[data-testid="preview-component"]')).toBeInTheDocument() + expect(document.body.querySelector('[data-testid="app-component"]')).not.toBeInTheDocument() + }) + }) + + it('switches back to App when Try tab is clicked', async () => { + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('Detail')).toBeInTheDocument() + }) + + // First switch to Detail + fireEvent.click(screen.getByText('Detail')) + + await waitFor(() => { + expect(document.body.querySelector('[data-testid="preview-component"]')).toBeInTheDocument() + }) + + // Then switch back to Try + fireEvent.click(screen.getByText('Try')) + + await waitFor(() => { + expect(document.body.querySelector('[data-testid="app-component"]')).toBeInTheDocument() + }) + }) + }) + + describe('close functionality', () => { + it('calls onClose when close button is clicked', async () => { + const mockOnClose = vi.fn() + + render( + , + ) + + await waitFor(() => { + // Find the button with close icon + const buttons = document.body.querySelectorAll('button') + const closeButton = Array.from(buttons).find(btn => + btn.querySelector('svg') || btn.className.includes('rounded-[10px]'), + ) + expect(closeButton).toBeInTheDocument() + + if (closeButton) + fireEvent.click(closeButton) + }) + + expect(mockOnClose).toHaveBeenCalled() + }) + }) + + describe('create functionality', () => { + it('calls onCreate when create button in AppInfo is clicked', async () => { + const mockOnCreate = vi.fn() + + render( + , + ) + + await waitFor(() => { + const createButton = document.body.querySelector('[data-testid="create-button"]') + expect(createButton).toBeInTheDocument() + + if (createButton) + fireEvent.click(createButton) + }) + + expect(mockOnCreate).toHaveBeenCalledTimes(1) + }) + }) + + describe('category prop', () => { + it('passes category to AppInfo when provided', async () => { + render( + , + ) + + await waitFor(() => { + const appInfo = document.body.querySelector('[data-testid="app-info-component"]') + expect(appInfo).toHaveAttribute('data-category', 'AI Assistant') + }) + }) + + it('does not pass category to AppInfo when not provided', async () => { + render( + , + ) + + await waitFor(() => { + const appInfo = document.body.querySelector('[data-testid="app-info-component"]') + expect(appInfo).not.toHaveAttribute('data-category', expect.any(String)) + }) + }) + }) + + describe('hook calls', () => { + it('calls useGetTryAppInfo with correct appId', () => { + render( + , + ) + + expect(mockUseGetTryAppInfo).toHaveBeenCalledWith('my-specific-app-id') + }) + }) + + describe('props passing', () => { + it('passes appId to App component', async () => { + render( + , + ) + + await waitFor(() => { + const appComponent = document.body.querySelector('[data-testid="app-component"]') + expect(appComponent).toHaveAttribute('data-app-id', 'my-app-id') + }) + }) + + it('passes appId to Preview component when in Detail mode', async () => { + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('Detail')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByText('Detail')) + + await waitFor(() => { + const previewComponent = document.body.querySelector('[data-testid="preview-component"]') + expect(previewComponent).toHaveAttribute('data-app-id', 'my-app-id') + }) + }) + + it('passes appId to AppInfo component', async () => { + render( + , + ) + + await waitFor(() => { + const appInfoComponent = document.body.querySelector('[data-testid="app-info-component"]') + expect(appInfoComponent).toHaveAttribute('data-app-id', 'my-app-id') + }) + }) + + it('passes appDetail to AppInfo component', async () => { + render( + , + ) + + await waitFor(() => { + const appInfoComponent = document.body.querySelector('[data-testid="app-info-component"]') + expect(appInfoComponent?.textContent).toContain('Test App Name') + }) + }) + }) + + describe('TypeEnum export', () => { + it('exports TypeEnum correctly', () => { + expect(TypeEnum.TRY).toBe('try') + expect(TypeEnum.DETAIL).toBe('detail') + }) + }) +}) diff --git a/web/app/components/explore/try-app/preview/basic-app-preview.spec.tsx b/web/app/components/explore/try-app/preview/basic-app-preview.spec.tsx new file mode 100644 index 0000000000..bf86d3f02f --- /dev/null +++ b/web/app/components/explore/try-app/preview/basic-app-preview.spec.tsx @@ -0,0 +1,527 @@ +import { cleanup, render, screen, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import BasicAppPreview from './basic-app-preview' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +const mockUseGetTryAppInfo = vi.fn() +const mockUseAllToolProviders = vi.fn() +const mockUseGetTryAppDataSets = vi.fn() +const mockUseTextGenerationCurrentProviderAndModelAndModelList = vi.fn() + +vi.mock('@/service/use-try-app', () => ({ + useGetTryAppInfo: (...args: unknown[]) => mockUseGetTryAppInfo(...args), + useGetTryAppDataSets: (...args: unknown[]) => mockUseGetTryAppDataSets(...args), +})) + +vi.mock('@/service/use-tools', () => ({ + useAllToolProviders: () => mockUseAllToolProviders(), +})) + +vi.mock('../../../header/account-setting/model-provider-page/hooks', () => ({ + useTextGenerationCurrentProviderAndModelAndModelList: (...args: unknown[]) => + mockUseTextGenerationCurrentProviderAndModelAndModelList(...args), +})) + +vi.mock('@/hooks/use-breakpoints', () => ({ + default: () => 'pc', + MediaType: { + mobile: 'mobile', + pc: 'pc', + }, +})) + +vi.mock('@/app/components/app/configuration/config', () => ({ + default: () =>
    Config
    , +})) + +vi.mock('@/app/components/app/configuration/debug', () => ({ + default: () =>
    Debug
    , +})) + +vi.mock('@/app/components/base/features', () => ({ + FeaturesProvider: ({ children }: { children: React.ReactNode }) => ( +
    {children}
    + ), +})) + +const createMockAppDetail = (mode: string = 'chat'): Record => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test Description', + mode, + site: { + title: 'Test Site Title', + icon: '๐Ÿš€', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + }, + model_config: { + model: { + provider: 'langgenius/openai/openai', + name: 'gpt-4', + mode: 'chat', + }, + pre_prompt: 'You are a helpful assistant', + user_input_form: [] as unknown[], + external_data_tools: [] as unknown[], + dataset_configs: { + datasets: { + datasets: [] as unknown[], + }, + }, + agent_mode: { + tools: [] as unknown[], + enabled: false, + }, + more_like_this: { enabled: false }, + opening_statement: 'Hello!', + suggested_questions: ['Question 1'], + sensitive_word_avoidance: null, + speech_to_text: null, + text_to_speech: null, + file_upload: null as unknown, + suggested_questions_after_answer: null, + retriever_resource: null, + annotation_reply: null, + }, + deleted_tools: [] as unknown[], +}) + +describe('BasicAppPreview', () => { + beforeEach(() => { + mockUseGetTryAppInfo.mockReturnValue({ + data: createMockAppDetail(), + isLoading: false, + }) + mockUseAllToolProviders.mockReturnValue({ + data: [], + isLoading: false, + }) + mockUseGetTryAppDataSets.mockReturnValue({ + data: { data: [] }, + isLoading: false, + }) + mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({ + currentModel: { + features: [], + }, + }) + }) + + afterEach(() => { + cleanup() + vi.clearAllMocks() + }) + + describe('loading state', () => { + it('renders loading when app detail is loading', () => { + mockUseGetTryAppInfo.mockReturnValue({ + data: null, + isLoading: true, + }) + + render() + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('renders loading when tool providers are loading', () => { + mockUseAllToolProviders.mockReturnValue({ + data: null, + isLoading: true, + }) + + render() + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('renders loading when datasets are loading', () => { + mockUseGetTryAppDataSets.mockReturnValue({ + data: null, + isLoading: true, + }) + + render() + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + }) + + describe('content rendering', () => { + it('renders Config component when data is loaded', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + + it('renders Debug component when data is loaded on PC', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('debug-component')).toBeInTheDocument() + }) + }) + + it('renders FeaturesProvider', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('features-provider')).toBeInTheDocument() + }) + }) + }) + + describe('different app modes', () => { + it('handles chat mode', async () => { + mockUseGetTryAppInfo.mockReturnValue({ + data: createMockAppDetail('chat'), + isLoading: false, + }) + + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + + it('handles completion mode', async () => { + mockUseGetTryAppInfo.mockReturnValue({ + data: createMockAppDetail('completion'), + isLoading: false, + }) + + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + + it('handles agent-chat mode', async () => { + const agentAppDetail = createMockAppDetail('agent-chat') + const modelConfig = agentAppDetail.model_config as Record + modelConfig.agent_mode = { + tools: [ + { + provider_id: 'test-provider', + provider_name: 'test-provider', + provider_type: 'builtin', + tool_name: 'test-tool', + enabled: true, + }, + ], + enabled: true, + max_iteration: 5, + } + + mockUseGetTryAppInfo.mockReturnValue({ + data: agentAppDetail, + isLoading: false, + }) + + mockUseAllToolProviders.mockReturnValue({ + data: [ + { + id: 'test-provider', + is_team_authorization: true, + icon: '/icon.png', + }, + ], + isLoading: false, + }) + + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + }) + + describe('hook calls', () => { + it('calls useGetTryAppInfo with correct appId', () => { + render() + + expect(mockUseGetTryAppInfo).toHaveBeenCalledWith('my-app-id') + }) + + it('calls useTextGenerationCurrentProviderAndModelAndModelList with model config', async () => { + render() + + await waitFor(() => { + expect(mockUseTextGenerationCurrentProviderAndModelAndModelList).toHaveBeenCalled() + }) + }) + }) + + describe('model features', () => { + it('handles vision feature', async () => { + mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({ + currentModel: { + features: ['vision'], + }, + }) + + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + + it('handles document feature', async () => { + mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({ + currentModel: { + features: ['document'], + }, + }) + + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + + it('handles audio feature', async () => { + mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({ + currentModel: { + features: ['audio'], + }, + }) + + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + + it('handles video feature', async () => { + mockUseTextGenerationCurrentProviderAndModelAndModelList.mockReturnValue({ + currentModel: { + features: ['video'], + }, + }) + + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + }) + + describe('dataset handling', () => { + it('handles app with datasets in agent mode', async () => { + const appWithDatasets = createMockAppDetail('agent-chat') + const modelConfig = appWithDatasets.model_config as Record + modelConfig.agent_mode = { + tools: [ + { + dataset: { + enabled: true, + id: 'dataset-1', + }, + }, + ], + enabled: true, + } + + mockUseGetTryAppInfo.mockReturnValue({ + data: appWithDatasets, + isLoading: false, + }) + + render() + + await waitFor(() => { + expect(mockUseGetTryAppDataSets).toHaveBeenCalled() + }) + }) + + it('handles app with datasets in dataset_configs', async () => { + const appWithDatasets = createMockAppDetail('chat') + const modelConfig = appWithDatasets.model_config as Record + modelConfig.dataset_configs = { + datasets: { + datasets: [ + { dataset: { id: 'dataset-1' } }, + { dataset: { id: 'dataset-2' } }, + ], + }, + } + + mockUseGetTryAppInfo.mockReturnValue({ + data: appWithDatasets, + isLoading: false, + }) + + render() + + await waitFor(() => { + expect(mockUseGetTryAppDataSets).toHaveBeenCalled() + }) + }) + }) + + describe('advanced prompt mode', () => { + it('handles advanced prompt mode', async () => { + const appWithAdvancedPrompt = createMockAppDetail('chat') + const modelConfig = appWithAdvancedPrompt.model_config as Record + modelConfig.prompt_type = 'advanced' + modelConfig.chat_prompt_config = { + prompt: [{ role: 'system', text: 'You are helpful' }], + } + + mockUseGetTryAppInfo.mockReturnValue({ + data: appWithAdvancedPrompt, + isLoading: false, + }) + + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + }) + + describe('file upload config', () => { + it('handles file upload config', async () => { + const appWithFileUpload = createMockAppDetail('chat') + const modelConfig = appWithFileUpload.model_config as Record + modelConfig.file_upload = { + enabled: true, + image: { + enabled: true, + detail: 'high', + number_limits: 5, + transfer_methods: ['local_file', 'remote_url'], + }, + allowed_file_types: ['image'], + allowed_file_extensions: ['.jpg', '.png'], + allowed_file_upload_methods: ['local_file'], + number_limits: 3, + } + + mockUseGetTryAppInfo.mockReturnValue({ + data: appWithFileUpload, + isLoading: false, + }) + + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + }) + + describe('external data tools', () => { + it('handles app with external_data_tools', async () => { + const appWithExternalTools = createMockAppDetail('chat') + const modelConfig = appWithExternalTools.model_config as Record + modelConfig.external_data_tools = [ + { + variable: 'test_var', + label: 'Test Label', + enabled: true, + type: 'text', + config: {}, + icon: '/icon.png', + icon_background: '#FFFFFF', + }, + ] + + mockUseGetTryAppInfo.mockReturnValue({ + data: appWithExternalTools, + isLoading: false, + }) + + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + }) + + describe('deleted tools handling', () => { + it('handles app with deleted tools', async () => { + const agentAppDetail = createMockAppDetail('agent-chat') + const modelConfig = agentAppDetail.model_config as Record + modelConfig.agent_mode = { + tools: [ + { + id: 'tool-1', + provider_id: 'test-provider', + provider_name: 'test-provider', + provider_type: 'builtin', + tool_name: 'test-tool', + enabled: true, + }, + ], + enabled: true, + max_iteration: 5, + } + agentAppDetail.deleted_tools = [ + { + id: 'tool-1', + tool_name: 'test-tool', + }, + ] + + mockUseGetTryAppInfo.mockReturnValue({ + data: agentAppDetail, + isLoading: false, + }) + + mockUseAllToolProviders.mockReturnValue({ + data: [ + { + id: 'test-provider', + is_team_authorization: false, + icon: '/icon.png', + }, + ], + isLoading: false, + }) + + render() + + await waitFor(() => { + expect(screen.getByTestId('config-component')).toBeInTheDocument() + }) + }) + }) + + describe('edge cases', () => { + it('handles app without model_config', async () => { + const appWithoutModelConfig = createMockAppDetail('chat') + appWithoutModelConfig.model_config = undefined + + mockUseGetTryAppInfo.mockReturnValue({ + data: appWithoutModelConfig, + isLoading: false, + }) + + render() + + // Should still render (with default model config) + await waitFor(() => { + expect(mockUseGetTryAppDataSets).toHaveBeenCalled() + }) + }) + }) +}) diff --git a/web/app/components/explore/try-app/preview/flow-app-preview.spec.tsx b/web/app/components/explore/try-app/preview/flow-app-preview.spec.tsx new file mode 100644 index 0000000000..c4e8175b82 --- /dev/null +++ b/web/app/components/explore/try-app/preview/flow-app-preview.spec.tsx @@ -0,0 +1,179 @@ +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import FlowAppPreview from './flow-app-preview' + +const mockUseGetTryAppFlowPreview = vi.fn() + +vi.mock('@/service/use-try-app', () => ({ + useGetTryAppFlowPreview: (...args: unknown[]) => mockUseGetTryAppFlowPreview(...args), +})) + +vi.mock('@/app/components/workflow/workflow-preview', () => ({ + default: ({ + className, + miniMapToRight, + nodes, + edges, + }: { className?: string, miniMapToRight?: boolean, nodes?: unknown[], edges?: unknown[] }) => ( +
    + WorkflowPreview +
    + ), +})) + +describe('FlowAppPreview', () => { + afterEach(() => { + cleanup() + vi.clearAllMocks() + }) + + describe('loading state', () => { + it('renders Loading component when isLoading is true', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: null, + isLoading: true, + }) + + render() + + expect(screen.getByRole('status')).toBeInTheDocument() + expect(screen.queryByTestId('workflow-preview')).not.toBeInTheDocument() + }) + }) + + describe('no data state', () => { + it('returns null when data is null', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: null, + isLoading: false, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it('returns null when data is undefined', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: undefined, + isLoading: false, + }) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + }) + + describe('data loaded state', () => { + it('renders WorkflowPreview when data is loaded', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [{ id: 'node1' }], + edges: [{ id: 'edge1' }], + }, + }, + isLoading: false, + }) + + render() + + expect(screen.getByTestId('workflow-preview')).toBeInTheDocument() + expect(screen.queryByRole('status')).not.toBeInTheDocument() + }) + + it('passes graph data to WorkflowPreview', () => { + const mockNodes = [{ id: 'node1' }, { id: 'node2' }, { id: 'node3' }] + const mockEdges = [{ id: 'edge1' }, { id: 'edge2' }] + + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: mockNodes, + edges: mockEdges, + }, + }, + isLoading: false, + }) + + render() + + const workflowPreview = screen.getByTestId('workflow-preview') + expect(workflowPreview).toHaveAttribute('data-nodes-count', '3') + expect(workflowPreview).toHaveAttribute('data-edges-count', '2') + }) + + it('passes miniMapToRight=true to WorkflowPreview', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [], + edges: [], + }, + }, + isLoading: false, + }) + + render() + + const workflowPreview = screen.getByTestId('workflow-preview') + expect(workflowPreview).toHaveAttribute('data-mini-map-to-right', 'true') + }) + + it('passes className to WorkflowPreview', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [], + edges: [], + }, + }, + isLoading: false, + }) + + render() + + const workflowPreview = screen.getByTestId('workflow-preview') + expect(workflowPreview).toHaveClass('custom-class') + }) + }) + + describe('hook calls', () => { + it('calls useGetTryAppFlowPreview with correct appId', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: null, + isLoading: true, + }) + + render() + + expect(mockUseGetTryAppFlowPreview).toHaveBeenCalledWith('my-specific-app-id') + }) + }) + + describe('wrapper styling', () => { + it('renders with correct wrapper classes when data is loaded', () => { + mockUseGetTryAppFlowPreview.mockReturnValue({ + data: { + graph: { + nodes: [], + edges: [], + }, + }, + isLoading: false, + }) + + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('h-full', 'w-full') + }) + }) +}) diff --git a/web/app/components/explore/try-app/preview/index.spec.tsx b/web/app/components/explore/try-app/preview/index.spec.tsx new file mode 100644 index 0000000000..022511efac --- /dev/null +++ b/web/app/components/explore/try-app/preview/index.spec.tsx @@ -0,0 +1,127 @@ +import type { TryAppInfo } from '@/service/try-app' +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import Preview from './index' + +vi.mock('./basic-app-preview', () => ({ + default: ({ appId }: { appId: string }) => ( +
    + BasicAppPreview +
    + ), +})) + +vi.mock('./flow-app-preview', () => ({ + default: ({ appId, className }: { appId: string, className?: string }) => ( +
    + FlowAppPreview +
    + ), +})) + +const createMockAppDetail = (mode: string): TryAppInfo => ({ + id: 'test-app-id', + name: 'Test App', + description: 'Test Description', + mode, + site: { + title: 'Test Site Title', + icon: 'icon', + icon_type: 'emoji', + icon_background: '#FFFFFF', + icon_url: '', + }, + model_config: { + model: { + provider: 'test/provider', + name: 'test-model', + mode: 'chat', + }, + dataset_configs: { + datasets: { + datasets: [], + }, + }, + agent_mode: { + tools: [], + }, + user_input_form: [], + }, +} as unknown as TryAppInfo) + +describe('Preview', () => { + afterEach(() => { + cleanup() + }) + + describe('basic app rendering', () => { + it('renders BasicAppPreview for agent-chat mode', () => { + const appDetail = createMockAppDetail('agent-chat') + render() + + expect(screen.getByTestId('basic-app-preview')).toBeInTheDocument() + expect(screen.queryByTestId('flow-app-preview')).not.toBeInTheDocument() + }) + + it('renders BasicAppPreview for chat mode', () => { + const appDetail = createMockAppDetail('chat') + render() + + expect(screen.getByTestId('basic-app-preview')).toBeInTheDocument() + expect(screen.queryByTestId('flow-app-preview')).not.toBeInTheDocument() + }) + + it('renders BasicAppPreview for completion mode', () => { + const appDetail = createMockAppDetail('completion') + render() + + expect(screen.getByTestId('basic-app-preview')).toBeInTheDocument() + expect(screen.queryByTestId('flow-app-preview')).not.toBeInTheDocument() + }) + + it('passes appId to BasicAppPreview', () => { + const appDetail = createMockAppDetail('chat') + render() + + const basicPreview = screen.getByTestId('basic-app-preview') + expect(basicPreview).toHaveAttribute('data-app-id', 'my-app-id') + }) + }) + + describe('flow app rendering', () => { + it('renders FlowAppPreview for workflow mode', () => { + const appDetail = createMockAppDetail('workflow') + render() + + expect(screen.getByTestId('flow-app-preview')).toBeInTheDocument() + expect(screen.queryByTestId('basic-app-preview')).not.toBeInTheDocument() + }) + + it('renders FlowAppPreview for advanced-chat mode', () => { + const appDetail = createMockAppDetail('advanced-chat') + render() + + expect(screen.getByTestId('flow-app-preview')).toBeInTheDocument() + expect(screen.queryByTestId('basic-app-preview')).not.toBeInTheDocument() + }) + + it('passes appId and className to FlowAppPreview', () => { + const appDetail = createMockAppDetail('workflow') + render() + + const flowPreview = screen.getByTestId('flow-app-preview') + expect(flowPreview).toHaveAttribute('data-app-id', 'my-flow-app-id') + expect(flowPreview).toHaveClass('h-full') + }) + }) + + describe('wrapper styling', () => { + it('renders with correct wrapper classes', () => { + const appDetail = createMockAppDetail('chat') + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('h-full', 'w-full') + }) + }) +}) diff --git a/web/app/components/explore/try-app/tab.spec.tsx b/web/app/components/explore/try-app/tab.spec.tsx new file mode 100644 index 0000000000..81bb841887 --- /dev/null +++ b/web/app/components/explore/try-app/tab.spec.tsx @@ -0,0 +1,58 @@ +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import Tab, { TypeEnum } from './tab' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + const translations: Record = { + 'tryApp.tabHeader.try': 'Try', + 'tryApp.tabHeader.detail': 'Detail', + } + return translations[key] || key + }, + }), +})) + +describe('Tab', () => { + afterEach(() => { + cleanup() + }) + + it('renders tab with TRY value selected', () => { + const mockOnChange = vi.fn() + render() + + expect(screen.getByText('Try')).toBeInTheDocument() + expect(screen.getByText('Detail')).toBeInTheDocument() + }) + + it('renders tab with DETAIL value selected', () => { + const mockOnChange = vi.fn() + render() + + expect(screen.getByText('Try')).toBeInTheDocument() + expect(screen.getByText('Detail')).toBeInTheDocument() + }) + + it('calls onChange when clicking a tab', () => { + const mockOnChange = vi.fn() + render() + + fireEvent.click(screen.getByText('Detail')) + expect(mockOnChange).toHaveBeenCalledWith(TypeEnum.DETAIL) + }) + + it('calls onChange when clicking Try tab', () => { + const mockOnChange = vi.fn() + render() + + fireEvent.click(screen.getByText('Try')) + expect(mockOnChange).toHaveBeenCalledWith(TypeEnum.TRY) + }) + + it('exports TypeEnum correctly', () => { + expect(TypeEnum.TRY).toBe('try') + expect(TypeEnum.DETAIL).toBe('detail') + }) +}) diff --git a/web/app/components/header/account-dropdown/compliance.tsx b/web/app/components/header/account-dropdown/compliance.tsx index 562914dd07..6bc5b5c3f1 100644 --- a/web/app/components/header/account-dropdown/compliance.tsx +++ b/web/app/components/header/account-dropdown/compliance.tsx @@ -10,6 +10,7 @@ import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' import { getDocDownloadUrl } from '@/service/common' import { cn } from '@/utils/classnames' +import { downloadUrl } from '@/utils/download' import Button from '../../base/button' import Gdpr from '../../base/icons/src/public/common/Gdpr' import Iso from '../../base/icons/src/public/common/Iso' @@ -47,9 +48,7 @@ const UpgradeOrDownload: FC = ({ doc_name }) => { mutationFn: async () => { try { const ret = await getDocDownloadUrl(doc_name) - const a = document.createElement('a') - a.href = ret.url - a.click() + downloadUrl({ url: ret.url }) Toast.notify({ type: 'success', message: t('operation.downloadSuccess', { ns: 'common' }), diff --git a/web/app/components/rag-pipeline/components/panel/input-field/footer-tip.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/footer-tip.spec.tsx new file mode 100644 index 0000000000..5d5cde9735 --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/footer-tip.spec.tsx @@ -0,0 +1,59 @@ +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import FooterTip from './footer-tip' + +afterEach(() => { + cleanup() + vi.clearAllMocks() +}) + +describe('FooterTip', () => { + describe('rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByText('Drag to adjust grouping')).toBeInTheDocument() + }) + + it('should render the drag tip text', () => { + render() + + expect(screen.getByText('Drag to adjust grouping')).toBeInTheDocument() + }) + + it('should have correct container classes', () => { + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'shrink-0', 'items-center', 'justify-center', 'gap-x-2', 'py-4') + }) + + it('should have correct text styling', () => { + render() + + const text = screen.getByText('Drag to adjust grouping') + expect(text).toHaveClass('system-xs-regular') + }) + + it('should have correct text color', () => { + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('text-text-quaternary') + }) + + it('should render the drag icon', () => { + const { container } = render() + + // The RiDragDropLine icon should be rendered + const icon = container.querySelector('.size-4') + expect(icon).toBeInTheDocument() + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((FooterTip as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/hooks.spec.ts b/web/app/components/rag-pipeline/components/panel/input-field/hooks.spec.ts new file mode 100644 index 0000000000..452963ba7f --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/hooks.spec.ts @@ -0,0 +1,166 @@ +import { renderHook } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { useFloatingRight } from './hooks' + +// Mock reactflow +const mockGetNodes = vi.fn() +vi.mock('reactflow', () => ({ + useStore: (selector: (s: { getNodes: () => { id: string, data: { selected: boolean } }[] }) => unknown) => { + return selector({ getNodes: mockGetNodes }) + }, +})) + +// Mock zustand/react/shallow +vi.mock('zustand/react/shallow', () => ({ + useShallow: (fn: (...args: unknown[]) => unknown) => fn, +})) + +// Mock workflow store +let mockNodePanelWidth = 400 +let mockWorkflowCanvasWidth: number | undefined = 1200 +let mockOtherPanelWidth = 0 + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: Record) => unknown) => { + return selector({ + nodePanelWidth: mockNodePanelWidth, + workflowCanvasWidth: mockWorkflowCanvasWidth, + otherPanelWidth: mockOtherPanelWidth, + }) + }, +})) + +beforeEach(() => { + mockNodePanelWidth = 400 + mockWorkflowCanvasWidth = 1200 + mockOtherPanelWidth = 0 + mockGetNodes.mockReturnValue([]) +}) + +afterEach(() => { + vi.clearAllMocks() +}) + +describe('useFloatingRight', () => { + describe('initial state', () => { + it('should return floatingRight as false initially', () => { + mockGetNodes.mockReturnValue([]) + + const { result } = renderHook(() => useFloatingRight(600)) + + expect(result.current.floatingRight).toBe(false) + }) + + it('should return floatingRightWidth as target width when not floating', () => { + mockGetNodes.mockReturnValue([]) + + const { result } = renderHook(() => useFloatingRight(600)) + + expect(result.current.floatingRightWidth).toBe(600) + }) + }) + + describe('with no selected node', () => { + it('should calculate space without node panel width', () => { + mockGetNodes.mockReturnValue([{ id: 'node-1', data: { selected: false } }]) + mockWorkflowCanvasWidth = 1000 + + const { result } = renderHook(() => useFloatingRight(400)) + + // leftWidth = 1000 - 0 (no selected node) - 0 - 400 - 4 = 596 + // 596 >= 404 so floatingRight should be false + expect(result.current.floatingRight).toBe(false) + }) + }) + + describe('with selected node', () => { + it('should subtract node panel width from available space', () => { + mockGetNodes.mockReturnValue([{ id: 'node-1', data: { selected: true } }]) + mockWorkflowCanvasWidth = 1200 + + const { result } = renderHook(() => useFloatingRight(400)) + + // leftWidth = 1200 - 400 (node panel) - 0 - 400 - 4 = 396 + // 396 < 404 so floatingRight should be true + expect(result.current.floatingRight).toBe(true) + }) + }) + + describe('floatingRightWidth calculation', () => { + it('should return target width when not floating', () => { + mockGetNodes.mockReturnValue([]) + mockWorkflowCanvasWidth = 2000 + + const { result } = renderHook(() => useFloatingRight(600)) + + expect(result.current.floatingRightWidth).toBe(600) + }) + + it('should return minimum of target width and available panel widths when floating with no selected node', () => { + mockGetNodes.mockReturnValue([]) + mockWorkflowCanvasWidth = 500 + mockOtherPanelWidth = 200 + + const { result } = renderHook(() => useFloatingRight(600)) + + // When floating and no selected node, width = min(600, 0 + 200) = 200 + expect(result.current.floatingRightWidth).toBeLessThanOrEqual(600) + }) + + it('should include node panel width when node is selected', () => { + mockGetNodes.mockReturnValue([{ id: 'node-1', data: { selected: true } }]) + mockWorkflowCanvasWidth = 500 + mockNodePanelWidth = 300 + mockOtherPanelWidth = 100 + + const { result } = renderHook(() => useFloatingRight(600)) + + // When floating with selected node, width = min(600, 300 + 100) = 400 + expect(result.current.floatingRightWidth).toBeLessThanOrEqual(600) + }) + }) + + describe('edge cases', () => { + it('should handle undefined workflowCanvasWidth', () => { + mockGetNodes.mockReturnValue([]) + mockWorkflowCanvasWidth = undefined + + const { result } = renderHook(() => useFloatingRight(400)) + + // Should not throw and should maintain initial state + expect(result.current.floatingRight).toBe(false) + }) + + it('should handle zero target element width', () => { + mockGetNodes.mockReturnValue([]) + + const { result } = renderHook(() => useFloatingRight(0)) + + expect(result.current.floatingRightWidth).toBe(0) + }) + + it('should handle very large target element width', () => { + mockGetNodes.mockReturnValue([]) + mockWorkflowCanvasWidth = 500 + + const { result } = renderHook(() => useFloatingRight(10000)) + + // Should be floating due to limited space + expect(result.current.floatingRight).toBe(true) + }) + + it('should return first selected node id when multiple nodes exist', () => { + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { selected: false } }, + { id: 'node-2', data: { selected: true } }, + { id: 'node-3', data: { selected: false } }, + ]) + mockWorkflowCanvasWidth = 1200 + + const { result } = renderHook(() => useFloatingRight(400)) + + // Should have selected node so node panel is considered + expect(result.current).toBeDefined() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/index.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/index.spec.tsx new file mode 100644 index 0000000000..71be12bb8d --- /dev/null +++ b/web/app/components/rag-pipeline/components/panel/input-field/label-right-content/index.spec.tsx @@ -0,0 +1,212 @@ +import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types' +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import { BlockEnum } from '@/app/components/workflow/types' +import Datasource from './datasource' +import GlobalInputs from './global-inputs' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock BlockIcon +vi.mock('@/app/components/workflow/block-icon', () => ({ + default: ({ type, toolIcon, className }: { type: BlockEnum, toolIcon?: string, className?: string }) => ( +
    + ), +})) + +// Mock useToolIcon +vi.mock('@/app/components/workflow/hooks', () => ({ + useToolIcon: (nodeData: DataSourceNodeType) => nodeData.provider_name || 'default-icon', +})) + +// Mock Tooltip +vi.mock('@/app/components/base/tooltip', () => ({ + default: ({ popupContent, popupClassName }: { popupContent: string, popupClassName?: string }) => ( +
    + ), +})) + +afterEach(() => { + cleanup() + vi.clearAllMocks() +}) + +describe('Datasource', () => { + const createMockNodeData = (overrides?: Partial): DataSourceNodeType => ({ + title: 'Test Data Source', + desc: 'Test description', + type: BlockEnum.DataSource, + provider_name: 'test-provider', + provider_type: 'api', + datasource_name: 'test-datasource', + datasource_label: 'Test Datasource', + plugin_id: 'test-plugin', + datasource_parameters: {}, + datasource_configurations: {}, + ...overrides, + } as DataSourceNodeType) + + describe('rendering', () => { + it('should render without crashing', () => { + const nodeData = createMockNodeData() + + render() + + expect(screen.getByTestId('block-icon')).toBeInTheDocument() + }) + + it('should render the node title', () => { + const nodeData = createMockNodeData({ title: 'My Custom Data Source' }) + + render() + + expect(screen.getByText('My Custom Data Source')).toBeInTheDocument() + }) + + it('should render BlockIcon with correct type', () => { + const nodeData = createMockNodeData() + + render() + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-type', BlockEnum.DataSource) + }) + + it('should pass toolIcon from useToolIcon hook', () => { + const nodeData = createMockNodeData({ provider_name: 'custom-provider' }) + + render() + + const blockIcon = screen.getByTestId('block-icon') + expect(blockIcon).toHaveAttribute('data-tool-icon', 'custom-provider') + }) + + it('should have correct icon container styling', () => { + const nodeData = createMockNodeData() + + const { container } = render() + + const iconContainer = container.querySelector('.size-5') + expect(iconContainer).toBeInTheDocument() + expect(iconContainer).toHaveClass('flex', 'items-center', 'justify-center', 'rounded-md') + }) + + it('should have correct text styling', () => { + const nodeData = createMockNodeData() + + render() + + const titleElement = screen.getByText('Test Data Source') + expect(titleElement).toHaveClass('system-sm-medium', 'text-text-secondary') + }) + + it('should have correct container layout', () => { + const nodeData = createMockNodeData() + + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'items-center', 'gap-x-1.5') + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((Datasource as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) + + describe('edge cases', () => { + it('should handle empty title', () => { + const nodeData = createMockNodeData({ title: '' }) + + render() + + // Should still render without the title text + expect(screen.getByTestId('block-icon')).toBeInTheDocument() + }) + + it('should handle long title', () => { + const longTitle = 'A'.repeat(100) + const nodeData = createMockNodeData({ title: longTitle }) + + render() + + expect(screen.getByText(longTitle)).toBeInTheDocument() + }) + + it('should handle special characters in title', () => { + const nodeData = createMockNodeData({ title: 'Test ' }) + + render() + + expect(screen.getByText('Test ')).toBeInTheDocument() + }) + }) +}) + +describe('GlobalInputs', () => { + describe('rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByText('inputFieldPanel.globalInputs.title')).toBeInTheDocument() + }) + + it('should render title with correct translation key', () => { + render() + + expect(screen.getByText('inputFieldPanel.globalInputs.title')).toBeInTheDocument() + }) + + it('should render tooltip component', () => { + render() + + expect(screen.getByTestId('tooltip')).toBeInTheDocument() + }) + + it('should pass correct tooltip content', () => { + render() + + const tooltip = screen.getByTestId('tooltip') + expect(tooltip).toHaveAttribute('data-content', 'inputFieldPanel.globalInputs.tooltip') + }) + + it('should have correct tooltip className', () => { + render() + + const tooltip = screen.getByTestId('tooltip') + expect(tooltip).toHaveClass('w-[240px]') + }) + + it('should have correct container layout', () => { + const { container } = render() + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('flex', 'items-center', 'gap-x-1') + }) + + it('should have correct title styling', () => { + render() + + const titleElement = screen.getByText('inputFieldPanel.globalInputs.title') + expect(titleElement).toHaveClass('system-sm-semibold-uppercase', 'text-text-secondary') + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((GlobalInputs as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/publish-toast.spec.tsx b/web/app/components/rag-pipeline/components/publish-toast.spec.tsx new file mode 100644 index 0000000000..d61f091ed2 --- /dev/null +++ b/web/app/components/rag-pipeline/components/publish-toast.spec.tsx @@ -0,0 +1,129 @@ +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import PublishToast from './publish-toast' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock workflow store with controllable state +let mockPublishedAt = 0 +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: Record) => unknown) => { + return selector({ publishedAt: mockPublishedAt }) + }, +})) + +afterEach(() => { + cleanup() + vi.clearAllMocks() +}) + +describe('PublishToast', () => { + beforeEach(() => { + mockPublishedAt = 0 + }) + + describe('rendering', () => { + it('should render when publishedAt is 0', () => { + mockPublishedAt = 0 + render() + + expect(screen.getByText('publishToast.title')).toBeInTheDocument() + }) + + it('should render toast title', () => { + render() + + expect(screen.getByText('publishToast.title')).toBeInTheDocument() + }) + + it('should render toast description', () => { + render() + + expect(screen.getByText('publishToast.desc')).toBeInTheDocument() + }) + + it('should not render when publishedAt is set', () => { + mockPublishedAt = Date.now() + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it('should have correct positioning classes', () => { + render() + + const container = screen.getByText('publishToast.title').closest('.absolute') + expect(container).toHaveClass('bottom-[45px]', 'left-0', 'right-0', 'z-10') + }) + + it('should render info icon', () => { + const { container } = render() + + // The RiInformation2Fill icon should be rendered + const iconContainer = container.querySelector('.text-text-accent') + expect(iconContainer).toBeInTheDocument() + }) + + it('should render close button', () => { + const { container } = render() + + // The close button is a div with cursor-pointer, not a semantic button + const closeButton = container.querySelector('.cursor-pointer') + expect(closeButton).toBeInTheDocument() + }) + }) + + describe('user interactions', () => { + it('should hide toast when close button is clicked', () => { + const { container } = render() + + // The close button is a div with cursor-pointer, not a semantic button + const closeButton = container.querySelector('.cursor-pointer') + expect(screen.getByText('publishToast.title')).toBeInTheDocument() + + fireEvent.click(closeButton!) + + expect(screen.queryByText('publishToast.title')).not.toBeInTheDocument() + }) + + it('should remain hidden after close button is clicked', () => { + const { container, rerender } = render() + + // The close button is a div with cursor-pointer, not a semantic button + const closeButton = container.querySelector('.cursor-pointer') + fireEvent.click(closeButton!) + + rerender() + + expect(screen.queryByText('publishToast.title')).not.toBeInTheDocument() + }) + }) + + describe('styling', () => { + it('should have gradient overlay', () => { + const { container } = render() + + const gradientOverlay = container.querySelector('.bg-gradient-to-r') + expect(gradientOverlay).toBeInTheDocument() + }) + + it('should have correct toast width', () => { + render() + + const toastContainer = screen.getByText('publishToast.title').closest('.w-\\[420px\\]') + expect(toastContainer).toBeInTheDocument() + }) + + it('should have rounded border', () => { + render() + + const toastContainer = screen.getByText('publishToast.title').closest('.rounded-xl') + expect(toastContainer).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/rag-pipeline-main.spec.tsx b/web/app/components/rag-pipeline/components/rag-pipeline-main.spec.tsx new file mode 100644 index 0000000000..3de3c3deeb --- /dev/null +++ b/web/app/components/rag-pipeline/components/rag-pipeline-main.spec.tsx @@ -0,0 +1,276 @@ +import type { PropsWithChildren } from 'react' +import type { Edge, Node, Viewport } from 'reactflow' +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import RagPipelineMain from './rag-pipeline-main' + +// Mock hooks from ../hooks +vi.mock('../hooks', () => ({ + useAvailableNodesMetaData: () => ({ nodes: [], nodesMap: {} }), + useDSL: () => ({ + exportCheck: vi.fn(), + handleExportDSL: vi.fn(), + }), + useGetRunAndTraceUrl: () => ({ + getWorkflowRunAndTraceUrl: vi.fn(), + }), + useNodesSyncDraft: () => ({ + doSyncWorkflowDraft: vi.fn(), + syncWorkflowDraftWhenPageClose: vi.fn(), + }), + usePipelineRefreshDraft: () => ({ + handleRefreshWorkflowDraft: vi.fn(), + }), + usePipelineRun: () => ({ + handleBackupDraft: vi.fn(), + handleLoadBackupDraft: vi.fn(), + handleRestoreFromPublishedWorkflow: vi.fn(), + handleRun: vi.fn(), + handleStopRun: vi.fn(), + }), + usePipelineStartRun: () => ({ + handleStartWorkflowRun: vi.fn(), + handleWorkflowStartRunInWorkflow: vi.fn(), + }), +})) + +// Mock useConfigsMap +vi.mock('../hooks/use-configs-map', () => ({ + useConfigsMap: () => ({ + flowId: 'test-flow-id', + flowType: 'ragPipeline', + fileSettings: {}, + }), +})) + +// Mock useInspectVarsCrud +vi.mock('../hooks/use-inspect-vars-crud', () => ({ + useInspectVarsCrud: () => ({ + hasNodeInspectVars: vi.fn(), + hasSetInspectVar: vi.fn(), + fetchInspectVarValue: vi.fn(), + editInspectVarValue: vi.fn(), + renameInspectVarName: vi.fn(), + appendNodeInspectVars: vi.fn(), + deleteInspectVar: vi.fn(), + deleteNodeInspectorVars: vi.fn(), + deleteAllInspectorVars: vi.fn(), + isInspectVarEdited: vi.fn(), + resetToLastRunVar: vi.fn(), + invalidateSysVarValues: vi.fn(), + resetConversationVar: vi.fn(), + invalidateConversationVarValues: vi.fn(), + }), +})) + +// Mock workflow store +const mockSetRagPipelineVariables = vi.fn() +const mockSetEnvironmentVariables = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + setRagPipelineVariables: mockSetRagPipelineVariables, + setEnvironmentVariables: mockSetEnvironmentVariables, + }), + }), +})) + +// Mock workflow hooks +vi.mock('@/app/components/workflow/hooks/use-fetch-workflow-inspect-vars', () => ({ + useSetWorkflowVarsWithValue: () => ({ + fetchInspectVars: vi.fn(), + }), +})) + +// Mock WorkflowWithInnerContext +vi.mock('@/app/components/workflow', () => ({ + WorkflowWithInnerContext: ({ children, onWorkflowDataUpdate }: PropsWithChildren<{ onWorkflowDataUpdate?: (payload: unknown) => void }>) => ( +
    + {children} + + +
    + ), +})) + +// Mock RagPipelineChildren +vi.mock('./rag-pipeline-children', () => ({ + default: () =>
    Children
    , +})) + +afterEach(() => { + cleanup() + vi.clearAllMocks() +}) + +describe('RagPipelineMain', () => { + const defaultProps = { + nodes: [] as Node[], + edges: [] as Edge[], + viewport: { x: 0, y: 0, zoom: 1 } as Viewport, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should render RagPipelineChildren', () => { + render() + + expect(screen.getByTestId('rag-pipeline-children')).toBeInTheDocument() + }) + + it('should pass nodes to WorkflowWithInnerContext', () => { + const nodes = [{ id: 'node-1', type: 'custom', position: { x: 0, y: 0 }, data: {} }] as Node[] + + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should pass edges to WorkflowWithInnerContext', () => { + const edges = [{ id: 'edge-1', source: 'node-1', target: 'node-2' }] as Edge[] + + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should pass viewport to WorkflowWithInnerContext', () => { + const viewport = { x: 100, y: 200, zoom: 1.5 } + + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + }) + + describe('handleWorkflowDataUpdate callback', () => { + it('should update rag_pipeline_variables when provided', () => { + render() + + const button = screen.getByTestId('trigger-update') + button.click() + + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([{ id: '1', name: 'var1' }]) + }) + + it('should update environment_variables when provided', () => { + render() + + const button = screen.getByTestId('trigger-update') + button.click() + + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([{ id: '2', name: 'env1' }]) + }) + + it('should only update rag_pipeline_variables when environment_variables is not provided', () => { + render() + + const button = screen.getByTestId('trigger-update-partial') + button.click() + + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([{ id: '3', name: 'var2' }]) + expect(mockSetEnvironmentVariables).not.toHaveBeenCalled() + }) + }) + + describe('hooks integration', () => { + it('should use useNodesSyncDraft hook', () => { + render() + + // If the component renders, the hook was called successfully + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use usePipelineRefreshDraft hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use usePipelineRun hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use usePipelineStartRun hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use useAvailableNodesMetaData hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use useGetRunAndTraceUrl hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use useDSL hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use useConfigsMap hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should use useInspectVarsCrud hook', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + }) + + describe('edge cases', () => { + it('should handle empty nodes array', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should handle empty edges array', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + + it('should handle default viewport', () => { + render() + + expect(screen.getByTestId('workflow-inner-context')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx b/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx new file mode 100644 index 0000000000..b96d3dfb1f --- /dev/null +++ b/web/app/components/rag-pipeline/components/update-dsl-modal.spec.tsx @@ -0,0 +1,1076 @@ +import type { PropsWithChildren } from 'react' +import { cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { DSLImportStatus } from '@/models/app' +import UpdateDSLModal from './update-dsl-modal' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock use-context-selector +const mockNotify = vi.fn() +vi.mock('use-context-selector', () => ({ + useContext: () => ({ notify: mockNotify }), +})) + +// Mock toast context +vi.mock('@/app/components/base/toast', () => ({ + ToastContext: { Provider: ({ children }: PropsWithChildren) => children }, +})) + +// Mock event emitter +const mockEmit = vi.fn() +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { emit: mockEmit }, + }), +})) + +// Mock workflow store +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: () => ({ + pipelineId: 'test-pipeline-id', + }), + }), +})) + +// Mock workflow utils +vi.mock('@/app/components/workflow/utils', () => ({ + initialNodes: (nodes: unknown[]) => nodes, + initialEdges: (edges: unknown[]) => edges, +})) + +// Mock plugin dependencies +const mockHandleCheckPluginDependencies = vi.fn() +vi.mock('@/app/components/workflow/plugin-dependency/hooks', () => ({ + usePluginDependencies: () => ({ + handleCheckPluginDependencies: mockHandleCheckPluginDependencies, + }), +})) + +// Mock pipeline service +const mockImportDSL = vi.fn() +const mockImportDSLConfirm = vi.fn() +vi.mock('@/service/use-pipeline', () => ({ + useImportPipelineDSL: () => ({ mutateAsync: mockImportDSL }), + useImportPipelineDSLConfirm: () => ({ mutateAsync: mockImportDSLConfirm }), +})) + +// Mock workflow service +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: vi.fn().mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: { x: 0, y: 0, zoom: 1 } }, + hash: 'test-hash', + rag_pipeline_variables: [], + }), +})) + +// Mock Uploader +vi.mock('@/app/components/app/create-from-dsl-modal/uploader', () => ({ + default: ({ updateFile }: { updateFile: (file?: File) => void }) => ( +
    + { + const file = e.target.files?.[0] + updateFile(file) + }} + /> + +
    + ), +})) + +// Mock Button +vi.mock('@/app/components/base/button', () => ({ + default: ({ children, onClick, disabled, className, variant, loading }: { + children: React.ReactNode + onClick?: () => void + disabled?: boolean + className?: string + variant?: string + loading?: boolean + }) => ( + + ), +})) + +// Mock Modal +vi.mock('@/app/components/base/modal', () => ({ + default: ({ children, isShow, _onClose, className }: PropsWithChildren<{ + isShow: boolean + _onClose: () => void + className?: string + }>) => isShow + ? ( +
    + {children} +
    + ) + : null, +})) + +// Mock workflow constants +vi.mock('@/app/components/workflow/constants', () => ({ + WORKFLOW_DATA_UPDATE: 'WORKFLOW_DATA_UPDATE', +})) + +// Mock FileReader +class MockFileReader { + result: string | null = null + onload: ((e: { target: { result: string | null } }) => void) | null = null + + readAsText(_file: File) { + // Simulate async file reading + setTimeout(() => { + this.result = 'test file content' + if (this.onload) { + this.onload({ target: { result: this.result } }) + } + }, 0) + } +} + +afterEach(() => { + cleanup() + vi.clearAllMocks() +}) + +describe('UpdateDSLModal', () => { + const mockOnCancel = vi.fn() + const mockOnBackup = vi.fn() + const mockOnImport = vi.fn() + let originalFileReader: typeof FileReader + + const defaultProps = { + onCancel: mockOnCancel, + onBackup: mockOnBackup, + onImport: mockOnImport, + } + + beforeEach(() => { + vi.clearAllMocks() + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + // Mock FileReader + originalFileReader = globalThis.FileReader + globalThis.FileReader = MockFileReader as unknown as typeof FileReader + }) + + afterEach(() => { + globalThis.FileReader = originalFileReader + }) + + describe('rendering', () => { + it('should render without crashing', () => { + render() + + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + + it('should render title', () => { + render() + + // The component uses t('common.importDSL', { ns: 'workflow' }) which returns 'common.importDSL' + expect(screen.getByText('common.importDSL')).toBeInTheDocument() + }) + + it('should render warning tip', () => { + render() + + // The component uses t('common.importDSLTip', { ns: 'workflow' }) + expect(screen.getByText('common.importDSLTip')).toBeInTheDocument() + }) + + it('should render uploader', () => { + render() + + expect(screen.getByTestId('uploader')).toBeInTheDocument() + }) + + it('should render backup button', () => { + render() + + // The component uses t('common.backupCurrentDraft', { ns: 'workflow' }) + expect(screen.getByText('common.backupCurrentDraft')).toBeInTheDocument() + }) + + it('should render cancel button', () => { + render() + + // The component uses t('newApp.Cancel', { ns: 'app' }) + expect(screen.getByText('newApp.Cancel')).toBeInTheDocument() + }) + + it('should render import button', () => { + render() + + // The component uses t('common.overwriteAndImport', { ns: 'workflow' }) + expect(screen.getByText('common.overwriteAndImport')).toBeInTheDocument() + }) + + it('should render choose DSL section', () => { + render() + + // The component uses t('common.chooseDSL', { ns: 'workflow' }) + expect(screen.getByText('common.chooseDSL')).toBeInTheDocument() + }) + }) + + describe('user interactions', () => { + it('should call onCancel when cancel button is clicked', () => { + render() + + const cancelButton = screen.getByText('newApp.Cancel') + fireEvent.click(cancelButton) + + expect(mockOnCancel).toHaveBeenCalled() + }) + + it('should call onBackup when backup button is clicked', () => { + render() + + const backupButton = screen.getByText('common.backupCurrentDraft') + fireEvent.click(backupButton) + + expect(mockOnBackup).toHaveBeenCalled() + }) + + it('should handle file upload', async () => { + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + + fireEvent.change(fileInput, { target: { files: [file] } }) + + // File should be processed + await waitFor(() => { + expect(screen.getByTestId('uploader')).toBeInTheDocument() + }) + }) + + it('should clear file when clear button is clicked', () => { + render() + + const clearButton = screen.getByTestId('clear-file') + fireEvent.click(clearButton) + + // File should be cleared + expect(screen.getByTestId('uploader')).toBeInTheDocument() + }) + + it('should call onCancel when close icon is clicked', () => { + render() + + // The close icon is in a div with onClick={onCancel} + const closeIconContainer = document.querySelector('.cursor-pointer') + if (closeIconContainer) { + fireEvent.click(closeIconContainer) + expect(mockOnCancel).toHaveBeenCalled() + } + }) + }) + + describe('import functionality', () => { + it('should show import button disabled when no file is selected', () => { + render() + + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).toBeDisabled() + }) + + it('should enable import button when file is selected', async () => { + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + }) + + it('should disable import button after file is cleared', async () => { + render() + + // First select a file + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + // Clear the file + const clearButton = screen.getByTestId('clear-file') + fireEvent.click(clearButton) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).toBeDisabled() + }) + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((UpdateDSLModal as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) + + describe('edge cases', () => { + it('should handle missing onImport callback', () => { + const props = { + onCancel: mockOnCancel, + onBackup: mockOnBackup, + } + + render() + + expect(screen.getByTestId('modal')).toBeInTheDocument() + }) + + it('should render import button with warning variant', () => { + render() + + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).toHaveAttribute('data-variant', 'warning') + }) + + it('should render backup button with secondary variant', () => { + render() + + // The backup button text is inside a nested div, so we need to find the closest button + const backupButtonText = screen.getByText('common.backupCurrentDraft') + const backupButton = backupButtonText.closest('button') + expect(backupButton).toHaveAttribute('data-variant', 'secondary') + }) + }) + + describe('import flow', () => { + it('should call importDSL when import button is clicked with file content', async () => { + render() + + // Select a file + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + // Wait for FileReader to process + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + // Click import button + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + // Wait for import to be called + await waitFor(() => { + expect(mockImportDSL).toHaveBeenCalled() + }) + }) + + it('should show success notification on completed import', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + // Select a file and click import + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + }) + }) + + it('should call onCancel after successful import', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockOnCancel).toHaveBeenCalled() + }) + }) + + it('should call onImport after successful import', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockOnImport).toHaveBeenCalled() + }) + }) + + it('should show warning notification on import with warnings', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED_WITH_WARNINGS, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'warning', + })) + }) + }) + + it('should show error notification when import fails', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.FAILED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should show error notification when pipeline_id is missing on success', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: undefined, + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should show error notification when import throws exception', async () => { + mockImportDSL.mockRejectedValue(new Error('Import failed')) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + // Wait for FileReader to complete (setTimeout 0) and button to be enabled + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + // Give extra time for the FileReader's setTimeout to complete + await new Promise(resolve => setTimeout(resolve, 10)) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should call handleCheckPluginDependencies on successful import', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('test-pipeline-id', true) + }) + }) + + it('should emit WORKFLOW_DATA_UPDATE event after successful import', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(mockEmit).toHaveBeenCalled() + }) + }) + + it('should show error modal when import status is PENDING', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + // Wait for the error modal to be shown after setTimeout + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + }) + + it('should show version info in error modal', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + // Wait for error modal with version info + await waitFor(() => { + expect(screen.getByText('1.0.0')).toBeInTheDocument() + expect(screen.getByText('2.0.0')).toBeInTheDocument() + }, { timeout: 500 }) + }) + + it('should close error modal when cancel button is clicked', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + // Wait for error modal + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + // Find and click cancel button in error modal - it should be the one with secondary variant + const cancelButtons = screen.getAllByText('newApp.Cancel') + const errorModalCancelButton = cancelButtons.find(btn => + btn.getAttribute('data-variant') === 'secondary', + ) + if (errorModalCancelButton) { + fireEvent.click(errorModalCancelButton) + } + + // Modal should be closed + await waitFor(() => { + expect(screen.queryByText('newApp.appCreateDSLErrorTitle')).not.toBeInTheDocument() + }) + }) + + it('should call importDSLConfirm when confirm button is clicked in error modal', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + // Wait for error modal + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + // Click confirm button + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockImportDSLConfirm).toHaveBeenCalledWith('import-id') + }) + }) + + it('should show success notification after confirm completes', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'success', + })) + }) + }) + + it('should show error notification when confirm fails with FAILED status', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.FAILED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should show error notification when confirm throws exception', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockRejectedValue(new Error('Confirm failed')) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should show error when confirm completes but pipeline_id is missing', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.COMPLETED, + pipeline_id: undefined, + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ + type: 'error', + })) + }) + }) + + it('should call onImport after confirm completes successfully', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockOnImport).toHaveBeenCalled() + }) + }) + + it('should call handleCheckPluginDependencies after confirm', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: '1.0.0', + current_dsl_version: '2.0.0', + }) + + mockImportDSLConfirm.mockResolvedValue({ + status: DSLImportStatus.COMPLETED, + pipeline_id: 'test-pipeline-id', + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + + const confirmButton = screen.getByText('newApp.Confirm') + fireEvent.click(confirmButton) + + await waitFor(() => { + expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('test-pipeline-id', true) + }) + }) + + it('should handle undefined imported_dsl_version and current_dsl_version', async () => { + mockImportDSL.mockResolvedValue({ + id: 'import-id', + status: DSLImportStatus.PENDING, + pipeline_id: 'test-pipeline-id', + imported_dsl_version: undefined, + current_dsl_version: undefined, + }) + + render() + + const fileInput = screen.getByTestId('file-input') + const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' }) + fireEvent.change(fileInput, { target: { files: [file] } }) + + await waitFor(() => { + const importButton = screen.getByText('common.overwriteAndImport') + expect(importButton).not.toBeDisabled() + }) + + const importButton = screen.getByText('common.overwriteAndImport') + fireEvent.click(importButton) + + // Should show error modal even with undefined versions + await waitFor(() => { + expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument() + }, { timeout: 500 }) + }) + + it('should not call importDSLConfirm when importId is not set', async () => { + // Render without triggering PENDING status first + render() + + // importId is not set, so confirm should not be called + // This is hard to test directly, but we can verify by checking the confirm flow + expect(mockImportDSLConfirm).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/index.spec.ts b/web/app/components/rag-pipeline/hooks/index.spec.ts new file mode 100644 index 0000000000..7917275c18 --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/index.spec.ts @@ -0,0 +1,536 @@ +import type { RAGPipelineVariables, VAR_TYPE_MAP } from '@/models/pipeline' +import { renderHook } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { BlockEnum } from '@/app/components/workflow/types' +import { Resolution, TransferMethod } from '@/types/app' +import { FlowType } from '@/types/common' + +// ============================================================================ +// Import hooks after mocks +// ============================================================================ + +import { + useAvailableNodesMetaData, + useDSL, + useGetRunAndTraceUrl, + useInputFieldPanel, + useNodesSyncDraft, + usePipelineInit, + usePipelineRefreshDraft, + usePipelineRun, + usePipelineStartRun, +} from './index' +import { useConfigsMap } from './use-configs-map' +import { useConfigurations, useInitialData } from './use-input-fields' +import { usePipelineTemplate } from './use-pipeline-template' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock the workflow store +const _mockGetState = vi.fn() +const mockUseStore = vi.fn() +const mockUseWorkflowStore = vi.fn() + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: Record) => unknown) => mockUseStore(selector), + useWorkflowStore: () => mockUseWorkflowStore(), +})) + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock toast context +const mockNotify = vi.fn() +vi.mock('@/app/components/base/toast', () => ({ + useToastContext: () => ({ + notify: mockNotify, + }), +})) + +// Mock event emitter context +const mockEventEmit = vi.fn() +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + emit: mockEventEmit, + }, + }), +})) + +// Mock i18n docLink +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path: string) => `https://docs.dify.ai${path}`, +})) + +// Mock workflow constants +vi.mock('@/app/components/workflow/constants', () => ({ + DSL_EXPORT_CHECK: 'DSL_EXPORT_CHECK', + WORKFLOW_DATA_UPDATE: 'WORKFLOW_DATA_UPDATE', + START_INITIAL_POSITION: { x: 100, y: 100 }, +})) + +// Mock workflow constants/node +vi.mock('@/app/components/workflow/constants/node', () => ({ + WORKFLOW_COMMON_NODES: [ + { + metaData: { type: BlockEnum.Start }, + defaultValue: { type: BlockEnum.Start }, + }, + { + metaData: { type: BlockEnum.End }, + defaultValue: { type: BlockEnum.End }, + }, + ], +})) + +// Mock data source defaults +vi.mock('@/app/components/workflow/nodes/data-source-empty/default', () => ({ + default: { + metaData: { type: BlockEnum.DataSourceEmpty }, + defaultValue: { type: BlockEnum.DataSourceEmpty }, + }, +})) + +vi.mock('@/app/components/workflow/nodes/data-source/default', () => ({ + default: { + metaData: { type: BlockEnum.DataSource }, + defaultValue: { type: BlockEnum.DataSource }, + }, +})) + +vi.mock('@/app/components/workflow/nodes/knowledge-base/default', () => ({ + default: { + metaData: { type: BlockEnum.KnowledgeBase }, + defaultValue: { type: BlockEnum.KnowledgeBase }, + }, +})) + +// Mock workflow utils with all needed exports +vi.mock('@/app/components/workflow/utils', async (importOriginal) => { + const actual = await importOriginal() as Record + return { + ...actual, + generateNewNode: ({ id, data, position }: { id: string, data: object, position: { x: number, y: number } }) => ({ + newNode: { id, data, position, type: 'custom' }, + }), + } +}) + +// Mock pipeline service +const mockExportPipelineConfig = vi.fn() +vi.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ + mutateAsync: mockExportPipelineConfig, + }), +})) + +// Mock workflow service +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: vi.fn().mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + environment_variables: [], + }), +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('useConfigsMap', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { + pipelineId: 'test-pipeline-id', + fileUploadConfig: { max_file_size: 10 }, + } + return selector(state) + }) + }) + + it('should return config map with correct flowId', () => { + const { result } = renderHook(() => useConfigsMap()) + + expect(result.current.flowId).toBe('test-pipeline-id') + }) + + it('should return config map with correct flowType', () => { + const { result } = renderHook(() => useConfigsMap()) + + expect(result.current.flowType).toBe(FlowType.ragPipeline) + }) + + it('should return file settings with image config', () => { + const { result } = renderHook(() => useConfigsMap()) + + expect(result.current.fileSettings.image).toEqual({ + enabled: false, + detail: Resolution.high, + number_limits: 3, + transfer_methods: [TransferMethod.local_file, TransferMethod.remote_url], + }) + }) + + it('should include fileUploadConfig from store', () => { + const { result } = renderHook(() => useConfigsMap()) + + expect(result.current.fileSettings.fileUploadConfig).toEqual({ max_file_size: 10 }) + }) +}) + +describe('useGetRunAndTraceUrl', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseWorkflowStore.mockReturnValue({ + getState: () => ({ + pipelineId: 'pipeline-123', + }), + }) + }) + + it('should return getWorkflowRunAndTraceUrl function', () => { + const { result } = renderHook(() => useGetRunAndTraceUrl()) + + expect(result.current.getWorkflowRunAndTraceUrl).toBeDefined() + expect(typeof result.current.getWorkflowRunAndTraceUrl).toBe('function') + }) + + it('should generate correct run URL', () => { + const { result } = renderHook(() => useGetRunAndTraceUrl()) + + const { runUrl } = result.current.getWorkflowRunAndTraceUrl('run-456') + + expect(runUrl).toBe('/rag/pipelines/pipeline-123/workflow-runs/run-456') + }) + + it('should generate correct trace URL', () => { + const { result } = renderHook(() => useGetRunAndTraceUrl()) + + const { traceUrl } = result.current.getWorkflowRunAndTraceUrl('run-456') + + expect(traceUrl).toBe('/rag/pipelines/pipeline-123/workflow-runs/run-456/node-executions') + }) +}) + +describe('useInputFieldPanel', () => { + const mockSetShowInputFieldPanel = vi.fn() + const mockSetShowInputFieldPreviewPanel = vi.fn() + const mockSetInputFieldEditPanelProps = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { + showInputFieldPreviewPanel: false, + inputFieldEditPanelProps: null, + } + return selector(state) + }) + mockUseWorkflowStore.mockReturnValue({ + getState: () => ({ + showInputFieldPreviewPanel: false, + setShowInputFieldPanel: mockSetShowInputFieldPanel, + setShowInputFieldPreviewPanel: mockSetShowInputFieldPreviewPanel, + setInputFieldEditPanelProps: mockSetInputFieldEditPanelProps, + }), + }) + }) + + it('should return isPreviewing as false when showInputFieldPreviewPanel is false', () => { + const { result } = renderHook(() => useInputFieldPanel()) + + expect(result.current.isPreviewing).toBe(false) + }) + + it('should return isPreviewing as true when showInputFieldPreviewPanel is true', () => { + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { + showInputFieldPreviewPanel: true, + inputFieldEditPanelProps: null, + } + return selector(state) + }) + + const { result } = renderHook(() => useInputFieldPanel()) + + expect(result.current.isPreviewing).toBe(true) + }) + + it('should return isEditing as false when inputFieldEditPanelProps is null', () => { + const { result } = renderHook(() => useInputFieldPanel()) + + expect(result.current.isEditing).toBe(false) + }) + + it('should return isEditing as true when inputFieldEditPanelProps exists', () => { + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { + showInputFieldPreviewPanel: false, + inputFieldEditPanelProps: { some: 'props' }, + } + return selector(state) + }) + + const { result } = renderHook(() => useInputFieldPanel()) + + expect(result.current.isEditing).toBe(true) + }) + + it('should call all setters when closeAllInputFieldPanels is called', () => { + const { result } = renderHook(() => useInputFieldPanel()) + + act(() => { + result.current.closeAllInputFieldPanels() + }) + + expect(mockSetShowInputFieldPanel).toHaveBeenCalledWith(false) + expect(mockSetShowInputFieldPreviewPanel).toHaveBeenCalledWith(false) + expect(mockSetInputFieldEditPanelProps).toHaveBeenCalledWith(null) + }) + + it('should toggle preview panel when toggleInputFieldPreviewPanel is called', () => { + const { result } = renderHook(() => useInputFieldPanel()) + + act(() => { + result.current.toggleInputFieldPreviewPanel() + }) + + expect(mockSetShowInputFieldPreviewPanel).toHaveBeenCalledWith(true) + }) + + it('should set edit panel props when toggleInputFieldEditPanel is called', () => { + const { result } = renderHook(() => useInputFieldPanel()) + const editContent = { type: 'edit', data: {} } + + act(() => { + // eslint-disable-next-line ts/no-explicit-any + result.current.toggleInputFieldEditPanel(editContent as any) + }) + + expect(mockSetInputFieldEditPanelProps).toHaveBeenCalledWith(editContent) + }) +}) + +describe('useInitialData', () => { + it('should return empty object for empty variables', () => { + const { result } = renderHook(() => useInitialData([], undefined)) + + expect(result.current).toEqual({}) + }) + + it('should handle text input type with default value', () => { + const variables: RAGPipelineVariables = [ + { + type: 'text-input' as keyof typeof VAR_TYPE_MAP, + variable: 'textVar', + label: 'Text', + required: false, + default_value: 'default text', + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useInitialData(variables, undefined)) + + expect(result.current.textVar).toBe('default text') + }) + + it('should use lastRunInputData over default value', () => { + const variables: RAGPipelineVariables = [ + { + type: 'text-input' as keyof typeof VAR_TYPE_MAP, + variable: 'textVar', + label: 'Text', + required: false, + default_value: 'default text', + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useInitialData(variables, { textVar: 'last run value' })) + + expect(result.current.textVar).toBe('last run value') + }) + + it('should handle number input type with default 0', () => { + const variables: RAGPipelineVariables = [ + { + type: 'number' as keyof typeof VAR_TYPE_MAP, + variable: 'numVar', + label: 'Number', + required: false, + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useInitialData(variables, undefined)) + + expect(result.current.numVar).toBe(0) + }) + + it('should handle file type with default empty array', () => { + const variables: RAGPipelineVariables = [ + { + type: 'file' as keyof typeof VAR_TYPE_MAP, + variable: 'fileVar', + label: 'File', + required: false, + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useInitialData(variables, undefined)) + + expect(result.current.fileVar).toEqual([]) + }) +}) + +describe('useConfigurations', () => { + it('should return empty array for empty variables', () => { + const { result } = renderHook(() => useConfigurations([])) + + expect(result.current).toEqual([]) + }) + + it('should transform variables to configurations', () => { + const variables: RAGPipelineVariables = [ + { + type: 'text-input' as keyof typeof VAR_TYPE_MAP, + variable: 'textVar', + label: 'Text Label', + required: true, + max_length: 100, + placeholder: 'Enter text', + tooltips: 'Help text', + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useConfigurations(variables)) + + expect(result.current.length).toBe(1) + expect(result.current[0].variable).toBe('textVar') + expect(result.current[0].label).toBe('Text Label') + expect(result.current[0].required).toBe(true) + expect(result.current[0].maxLength).toBe(100) + expect(result.current[0].placeholder).toBe('Enter text') + expect(result.current[0].tooltip).toBe('Help text') + }) + + it('should transform options correctly', () => { + const variables: RAGPipelineVariables = [ + { + type: 'select' as keyof typeof VAR_TYPE_MAP, + variable: 'selectVar', + label: 'Select', + required: false, + options: ['option1', 'option2', 'option3'], + belong_to_node_id: 'node-1', + }, + ] + + const { result } = renderHook(() => useConfigurations(variables)) + + expect(result.current[0].options).toEqual([ + { label: 'option1', value: 'option1' }, + { label: 'option2', value: 'option2' }, + { label: 'option3', value: 'option3' }, + ]) + }) +}) + +describe('useAvailableNodesMetaData', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should return nodes array', () => { + const { result } = renderHook(() => useAvailableNodesMetaData()) + + expect(result.current.nodes).toBeDefined() + expect(Array.isArray(result.current.nodes)).toBe(true) + }) + + it('should return nodesMap object', () => { + const { result } = renderHook(() => useAvailableNodesMetaData()) + + expect(result.current.nodesMap).toBeDefined() + expect(typeof result.current.nodesMap).toBe('object') + }) +}) + +describe('usePipelineTemplate', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should return nodes array with knowledge base node', () => { + const { result } = renderHook(() => usePipelineTemplate()) + + expect(result.current.nodes).toBeDefined() + expect(Array.isArray(result.current.nodes)).toBe(true) + expect(result.current.nodes.length).toBe(1) + }) + + it('should return empty edges array', () => { + const { result } = renderHook(() => usePipelineTemplate()) + + expect(result.current.edges).toEqual([]) + }) +}) + +describe('useDSL', () => { + it('should be defined and exported', () => { + expect(useDSL).toBeDefined() + expect(typeof useDSL).toBe('function') + }) +}) + +describe('exports', () => { + it('should export useAvailableNodesMetaData', () => { + expect(useAvailableNodesMetaData).toBeDefined() + }) + + it('should export useDSL', () => { + expect(useDSL).toBeDefined() + }) + + it('should export useGetRunAndTraceUrl', () => { + expect(useGetRunAndTraceUrl).toBeDefined() + }) + + it('should export useInputFieldPanel', () => { + expect(useInputFieldPanel).toBeDefined() + }) + + it('should export useNodesSyncDraft', () => { + expect(useNodesSyncDraft).toBeDefined() + }) + + it('should export usePipelineInit', () => { + expect(usePipelineInit).toBeDefined() + }) + + it('should export usePipelineRefreshDraft', () => { + expect(usePipelineRefreshDraft).toBeDefined() + }) + + it('should export usePipelineRun', () => { + expect(usePipelineRun).toBeDefined() + }) + + it('should export usePipelineStartRun', () => { + expect(usePipelineStartRun).toBeDefined() + }) +}) + +afterEach(() => { + vi.clearAllMocks() +}) diff --git a/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts b/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts new file mode 100644 index 0000000000..0f235516e0 --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-DSL.spec.ts @@ -0,0 +1,368 @@ +import { renderHook } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { useDSL } from './use-DSL' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock toast context +const mockNotify = vi.fn() +vi.mock('@/app/components/base/toast', () => ({ + useToastContext: () => ({ + notify: mockNotify, + }), +})) + +// Mock event emitter context +const mockEmit = vi.fn() +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + emit: mockEmit, + }, + }), +})) + +// Mock workflow store +const mockWorkflowStoreGetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + }), +})) + +// Mock useNodesSyncDraft +const mockDoSyncWorkflowDraft = vi.fn() +vi.mock('./use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ + doSyncWorkflowDraft: mockDoSyncWorkflowDraft, + }), +})) + +// Mock pipeline service +const mockExportPipelineConfig = vi.fn() +vi.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ + mutateAsync: mockExportPipelineConfig, + }), +})) + +// Mock workflow service +const mockFetchWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (url: string) => mockFetchWorkflowDraft(url), +})) + +// Mock workflow constants +vi.mock('@/app/components/workflow/constants', () => ({ + DSL_EXPORT_CHECK: 'DSL_EXPORT_CHECK', +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('useDSL', () => { + let mockLink: { href: string, download: string, click: ReturnType } + let originalCreateElement: typeof document.createElement + let mockCreateObjectURL: ReturnType + let mockRevokeObjectURL: ReturnType + + beforeEach(() => { + vi.clearAllMocks() + + // Create a proper mock link element + mockLink = { + href: '', + download: '', + click: vi.fn(), + } + + // Save original and mock selectively - only intercept 'a' elements + originalCreateElement = document.createElement.bind(document) + document.createElement = vi.fn((tagName: string) => { + if (tagName === 'a') { + return mockLink as unknown as HTMLElement + } + return originalCreateElement(tagName) + }) as typeof document.createElement + + mockCreateObjectURL = vi.spyOn(URL, 'createObjectURL').mockReturnValue('blob:test-url') + mockRevokeObjectURL = vi.spyOn(URL, 'revokeObjectURL').mockImplementation(() => {}) + + // Default store state + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + knowledgeName: 'Test Knowledge Base', + }) + + mockDoSyncWorkflowDraft.mockResolvedValue(undefined) + mockExportPipelineConfig.mockResolvedValue({ data: 'yaml-content' }) + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [], + }) + }) + + afterEach(() => { + document.createElement = originalCreateElement + mockCreateObjectURL.mockRestore() + mockRevokeObjectURL.mockRestore() + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return exportCheck function', () => { + const { result } = renderHook(() => useDSL()) + + expect(result.current.exportCheck).toBeDefined() + expect(typeof result.current.exportCheck).toBe('function') + }) + + it('should return handleExportDSL function', () => { + const { result } = renderHook(() => useDSL()) + + expect(result.current.handleExportDSL).toBeDefined() + expect(typeof result.current.handleExportDSL).toBe('function') + }) + }) + + describe('handleExportDSL', () => { + it('should not export when pipelineId is missing', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: undefined, + knowledgeName: 'Test', + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockDoSyncWorkflowDraft).not.toHaveBeenCalled() + expect(mockExportPipelineConfig).not.toHaveBeenCalled() + }) + + it('should sync workflow draft before export', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should call exportPipelineConfig with correct params', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL(true) + }) + + expect(mockExportPipelineConfig).toHaveBeenCalledWith({ + pipelineId: 'test-pipeline-id', + include: true, + }) + }) + + it('should create and download file', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(document.createElement).toHaveBeenCalledWith('a') + expect(mockCreateObjectURL).toHaveBeenCalled() + expect(mockRevokeObjectURL).toHaveBeenCalledWith('blob:test-url') + }) + + it('should use correct file extension for download', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockLink.download).toBe('Test Knowledge Base.pipeline') + }) + + it('should trigger download click', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockLink.click).toHaveBeenCalled() + }) + + it('should show error notification on export failure', async () => { + mockExportPipelineConfig.mockRejectedValue(new Error('Export failed')) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.handleExportDSL() + }) + + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'exportFailed', + }) + }) + }) + + describe('exportCheck', () => { + it('should not check when pipelineId is missing', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: undefined, + knowledgeName: 'Test', + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockFetchWorkflowDraft).not.toHaveBeenCalled() + }) + + it('should fetch workflow draft', async () => { + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft') + }) + + it('should directly export when no secret environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [ + { id: '1', value_type: 'string', value: 'test' }, + ], + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + // Should call doSyncWorkflowDraft (which means handleExportDSL was called) + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should emit DSL_EXPORT_CHECK event when secret variables exist', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [ + { id: '1', value_type: 'secret', value: 'secret-value' }, + ], + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockEmit).toHaveBeenCalledWith({ + type: 'DSL_EXPORT_CHECK', + payload: { + data: [{ id: '1', value_type: 'secret', value: 'secret-value' }], + }, + }) + }) + + it('should show error notification on check failure', async () => { + mockFetchWorkflowDraft.mockRejectedValue(new Error('Fetch failed')) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: 'exportFailed', + }) + }) + + it('should filter only secret environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [ + { id: '1', value_type: 'string', value: 'plain' }, + { id: '2', value_type: 'secret', value: 'secret1' }, + { id: '3', value_type: 'number', value: '123' }, + { id: '4', value_type: 'secret', value: 'secret2' }, + ], + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockEmit).toHaveBeenCalledWith({ + type: 'DSL_EXPORT_CHECK', + payload: { + data: [ + { id: '2', value_type: 'secret', value: 'secret1' }, + { id: '4', value_type: 'secret', value: 'secret2' }, + ], + }, + }) + }) + + it('should handle empty environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [], + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + // Should directly call handleExportDSL since no secrets + expect(mockEmit).not.toHaveBeenCalled() + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should handle undefined environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: undefined, + }) + + const { result } = renderHook(() => useDSL()) + + await act(async () => { + await result.current.exportCheck() + }) + + // Should directly call handleExportDSL since no secrets + expect(mockEmit).not.toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-DSL.ts b/web/app/components/rag-pipeline/hooks/use-DSL.ts index 1660d555eb..5c0f9def1c 100644 --- a/web/app/components/rag-pipeline/hooks/use-DSL.ts +++ b/web/app/components/rag-pipeline/hooks/use-DSL.ts @@ -11,6 +11,7 @@ import { useWorkflowStore } from '@/app/components/workflow/store' import { useEventEmitterContextContext } from '@/context/event-emitter' import { useExportPipelineDSL } from '@/service/use-pipeline' import { fetchWorkflowDraft } from '@/service/workflow' +import { downloadBlob } from '@/utils/download' import { useNodesSyncDraft } from './use-nodes-sync-draft' export const useDSL = () => { @@ -37,13 +38,8 @@ export const useDSL = () => { pipelineId, include, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${knowledgeName}.pipeline` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${knowledgeName}.pipeline` }) } catch { notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) diff --git a/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.spec.ts b/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.spec.ts new file mode 100644 index 0000000000..5817d187ac --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.spec.ts @@ -0,0 +1,469 @@ +import { renderHook } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { useNodesSyncDraft } from './use-nodes-sync-draft' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock reactflow +const mockGetNodes = vi.fn() +const mockStoreGetState = vi.fn() + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: mockStoreGetState, + }), +})) + +// Mock workflow store +const mockWorkflowStoreGetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + }), +})) + +// Mock useNodesReadOnly +const mockGetNodesReadOnly = vi.fn() +vi.mock('@/app/components/workflow/hooks/use-workflow', () => ({ + useNodesReadOnly: () => ({ + getNodesReadOnly: mockGetNodesReadOnly, + }), +})) + +// Mock useSerialAsyncCallback - must pass through arguments +vi.mock('@/app/components/workflow/hooks/use-serial-async-callback', () => ({ + useSerialAsyncCallback: (fn: (...args: unknown[]) => Promise, checkFn: () => boolean) => { + return (...args: unknown[]) => { + if (!checkFn()) { + return fn(...args) + } + } + }, +})) + +// Mock service +const mockSyncWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + syncWorkflowDraft: (params: unknown) => mockSyncWorkflowDraft(params), +})) + +// Mock usePipelineRefreshDraft +const mockHandleRefreshWorkflowDraft = vi.fn() +vi.mock('@/app/components/rag-pipeline/hooks', () => ({ + usePipelineRefreshDraft: () => ({ + handleRefreshWorkflowDraft: mockHandleRefreshWorkflowDraft, + }), +})) + +// Mock API_PREFIX +vi.mock('@/config', () => ({ + API_PREFIX: '/api', +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('useNodesSyncDraft', () => { + const mockSendBeacon = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + // Setup navigator.sendBeacon mock + Object.defineProperty(navigator, 'sendBeacon', { + value: mockSendBeacon, + writable: true, + configurable: true, + }) + + // Default store state + mockStoreGetState.mockReturnValue({ + getNodes: mockGetNodes, + edges: [], + transform: [0, 0, 1], + }) + + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start', _temp: true }, position: { x: 0, y: 0 } }, + { id: 'node-2', data: { type: 'end' }, position: { x: 100, y: 0 } }, + ]) + + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + environmentVariables: [], + syncWorkflowDraftHash: 'test-hash', + ragPipelineVariables: [], + setSyncWorkflowDraftHash: vi.fn(), + setDraftUpdatedAt: vi.fn(), + }) + + mockGetNodesReadOnly.mockReturnValue(false) + mockSyncWorkflowDraft.mockResolvedValue({ + hash: 'new-hash', + updated_at: '2024-01-01T00:00:00Z', + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return doSyncWorkflowDraft function', () => { + const { result } = renderHook(() => useNodesSyncDraft()) + + expect(result.current.doSyncWorkflowDraft).toBeDefined() + expect(typeof result.current.doSyncWorkflowDraft).toBe('function') + }) + + it('should return syncWorkflowDraftWhenPageClose function', () => { + const { result } = renderHook(() => useNodesSyncDraft()) + + expect(result.current.syncWorkflowDraftWhenPageClose).toBeDefined() + expect(typeof result.current.syncWorkflowDraftWhenPageClose).toBe('function') + }) + }) + + describe('syncWorkflowDraftWhenPageClose', () => { + it('should not call sendBeacon when nodes are read only', () => { + mockGetNodesReadOnly.mockReturnValue(true) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + expect(mockSendBeacon).not.toHaveBeenCalled() + }) + + it('should call sendBeacon with correct URL and params', () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + expect(mockSendBeacon).toHaveBeenCalledWith( + '/api/rag/pipelines/test-pipeline-id/workflows/draft', + expect.any(String), + ) + }) + + it('should not call sendBeacon when pipelineId is missing', () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: undefined, + environmentVariables: [], + syncWorkflowDraftHash: 'test-hash', + ragPipelineVariables: [], + }) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + expect(mockSendBeacon).not.toHaveBeenCalled() + }) + + it('should not call sendBeacon when nodes array is empty', () => { + mockGetNodes.mockReturnValue([]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + expect(mockSendBeacon).not.toHaveBeenCalled() + }) + + it('should filter out temp nodes', () => { + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start', _isTempNode: true }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + // Should not call sendBeacon because after filtering temp nodes, array is empty + expect(mockSendBeacon).not.toHaveBeenCalled() + }) + + it('should remove underscore-prefixed data keys from nodes', () => { + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start', _privateData: 'secret' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + expect(mockSendBeacon).toHaveBeenCalled() + const sentData = JSON.parse(mockSendBeacon.mock.calls[0][1]) + expect(sentData.graph.nodes[0].data._privateData).toBeUndefined() + }) + }) + + describe('doSyncWorkflowDraft', () => { + it('should not sync when nodes are read only', async () => { + mockGetNodesReadOnly.mockReturnValue(true) + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft() + }) + + expect(mockSyncWorkflowDraft).not.toHaveBeenCalled() + }) + + it('should call syncWorkflowDraft service', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft() + }) + + expect(mockSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should call onSuccess callback when sync succeeds', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + const onSuccess = vi.fn() + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(false, { onSuccess }) + }) + + expect(onSuccess).toHaveBeenCalled() + }) + + it('should call onSettled callback after sync completes', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + const onSettled = vi.fn() + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(false, { onSettled }) + }) + + expect(onSettled).toHaveBeenCalled() + }) + + it('should call onError callback when sync fails', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + mockSyncWorkflowDraft.mockRejectedValue(new Error('Sync failed')) + const onError = vi.fn() + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(false, { onError }) + }) + + expect(onError).toHaveBeenCalled() + }) + + it('should update hash and draft updated at on success', async () => { + const mockSetSyncWorkflowDraftHash = vi.fn() + const mockSetDraftUpdatedAt = vi.fn() + + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + environmentVariables: [], + syncWorkflowDraftHash: 'test-hash', + ragPipelineVariables: [], + setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash, + setDraftUpdatedAt: mockSetDraftUpdatedAt, + }) + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft() + }) + + expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash') + expect(mockSetDraftUpdatedAt).toHaveBeenCalledWith('2024-01-01T00:00:00Z') + }) + + it('should handle draft not sync error', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const mockJsonError = { + json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_sync' }), + bodyUsed: false, + } + mockSyncWorkflowDraft.mockRejectedValue(mockJsonError) + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(false) + }) + + // Wait for json to be called + await new Promise(resolve => setTimeout(resolve, 0)) + + expect(mockHandleRefreshWorkflowDraft).toHaveBeenCalled() + }) + + it('should not refresh when notRefreshWhenSyncError is true', async () => { + mockGetNodesReadOnly.mockReturnValue(false) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const mockJsonError = { + json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_sync' }), + bodyUsed: false, + } + mockSyncWorkflowDraft.mockRejectedValue(mockJsonError) + + const { result } = renderHook(() => useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(true) + }) + + // Wait for json to be called + await new Promise(resolve => setTimeout(resolve, 0)) + + expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled() + }) + }) + + describe('getPostParams', () => { + it('should include viewport coordinates in params', () => { + mockStoreGetState.mockReturnValue({ + getNodes: mockGetNodes, + edges: [], + transform: [100, 200, 1.5], + }) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + const sentData = JSON.parse(mockSendBeacon.mock.calls[0][1]) + expect(sentData.graph.viewport).toEqual({ x: 100, y: 200, zoom: 1.5 }) + }) + + it('should include environment variables in params', () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + environmentVariables: [{ key: 'API_KEY', value: 'secret' }], + syncWorkflowDraftHash: 'test-hash', + ragPipelineVariables: [], + setSyncWorkflowDraftHash: vi.fn(), + setDraftUpdatedAt: vi.fn(), + }) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + const sentData = JSON.parse(mockSendBeacon.mock.calls[0][1]) + expect(sentData.environment_variables).toEqual([{ key: 'API_KEY', value: 'secret' }]) + }) + + it('should include rag pipeline variables in params', () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + environmentVariables: [], + syncWorkflowDraftHash: 'test-hash', + ragPipelineVariables: [{ variable: 'input', type: 'text-input' }], + setSyncWorkflowDraftHash: vi.fn(), + setDraftUpdatedAt: vi.fn(), + }) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + const sentData = JSON.parse(mockSendBeacon.mock.calls[0][1]) + expect(sentData.rag_pipeline_variables).toEqual([{ variable: 'input', type: 'text-input' }]) + }) + + it('should remove underscore-prefixed keys from edges', () => { + mockStoreGetState.mockReturnValue({ + getNodes: mockGetNodes, + edges: [{ id: 'edge-1', source: 'node-1', target: 'node-2', data: { _hidden: true, visible: false } }], + transform: [0, 0, 1], + }) + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start' }, position: { x: 0, y: 0 } }, + ]) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + const sentData = JSON.parse(mockSendBeacon.mock.calls[0][1]) + expect(sentData.graph.edges[0].data._hidden).toBeUndefined() + expect(sentData.graph.edges[0].data.visible).toBe(false) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-pipeline-config.spec.ts b/web/app/components/rag-pipeline/hooks/use-pipeline-config.spec.ts new file mode 100644 index 0000000000..491d2828d8 --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-pipeline-config.spec.ts @@ -0,0 +1,299 @@ +import { renderHook } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { usePipelineConfig } from './use-pipeline-config' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock workflow store +const mockUseStore = vi.fn() +const mockWorkflowStoreGetState = vi.fn() + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: Record) => unknown) => mockUseStore(selector), + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + }), +})) + +// Mock useWorkflowConfig +const mockUseWorkflowConfig = vi.fn() +vi.mock('@/service/use-workflow', () => ({ + useWorkflowConfig: (url: string, callback: (data: unknown) => void) => mockUseWorkflowConfig(url, callback), +})) + +// Mock useDataSourceList +const mockUseDataSourceList = vi.fn() +vi.mock('@/service/use-pipeline', () => ({ + useDataSourceList: (enabled: boolean, callback: (data: unknown) => void) => mockUseDataSourceList(enabled, callback), +})) + +// Mock basePath +vi.mock('@/utils/var', () => ({ + basePath: '/base', +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('usePipelineConfig', () => { + const mockSetNodesDefaultConfigs = vi.fn() + const mockSetPublishedAt = vi.fn() + const mockSetDataSourceList = vi.fn() + const mockSetFileUploadConfig = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { pipelineId: 'test-pipeline-id' } + return selector(state) + }) + + mockWorkflowStoreGetState.mockReturnValue({ + setNodesDefaultConfigs: mockSetNodesDefaultConfigs, + setPublishedAt: mockSetPublishedAt, + setDataSourceList: mockSetDataSourceList, + setFileUploadConfig: mockSetFileUploadConfig, + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should render without crashing', () => { + expect(() => renderHook(() => usePipelineConfig())).not.toThrow() + }) + + it('should call useWorkflowConfig with correct URL for nodes default configs', () => { + renderHook(() => usePipelineConfig()) + + expect(mockUseWorkflowConfig).toHaveBeenCalledWith( + '/rag/pipelines/test-pipeline-id/workflows/default-workflow-block-configs', + expect.any(Function), + ) + }) + + it('should call useWorkflowConfig with correct URL for published workflow', () => { + renderHook(() => usePipelineConfig()) + + expect(mockUseWorkflowConfig).toHaveBeenCalledWith( + '/rag/pipelines/test-pipeline-id/workflows/publish', + expect.any(Function), + ) + }) + + it('should call useWorkflowConfig with correct URL for file upload config', () => { + renderHook(() => usePipelineConfig()) + + expect(mockUseWorkflowConfig).toHaveBeenCalledWith( + '/files/upload', + expect.any(Function), + ) + }) + + it('should call useDataSourceList when pipelineId exists', () => { + renderHook(() => usePipelineConfig()) + + expect(mockUseDataSourceList).toHaveBeenCalledWith(true, expect.any(Function)) + }) + + it('should call useDataSourceList with false when pipelineId is missing', () => { + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { pipelineId: undefined } + return selector(state) + }) + + renderHook(() => usePipelineConfig()) + + expect(mockUseDataSourceList).toHaveBeenCalledWith(false, expect.any(Function)) + }) + + it('should use empty URL when pipelineId is missing for nodes configs', () => { + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + const state = { pipelineId: undefined } + return selector(state) + }) + + renderHook(() => usePipelineConfig()) + + expect(mockUseWorkflowConfig).toHaveBeenCalledWith('', expect.any(Function)) + }) + }) + + describe('handleUpdateNodesDefaultConfigs', () => { + it('should handle array format configs', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseWorkflowConfig.mockImplementation((url: string, callback: (data: unknown) => void) => { + if (url.includes('default-workflow-block-configs')) { + capturedCallback = callback + } + }) + + renderHook(() => usePipelineConfig()) + + const arrayConfigs = [ + { type: 'llm', config: { model: 'gpt-4' } }, + { type: 'code', config: { language: 'python' } }, + ] + + capturedCallback?.(arrayConfigs) + + expect(mockSetNodesDefaultConfigs).toHaveBeenCalledWith({ + llm: { model: 'gpt-4' }, + code: { language: 'python' }, + }) + }) + + it('should handle object format configs', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseWorkflowConfig.mockImplementation((url: string, callback: (data: unknown) => void) => { + if (url.includes('default-workflow-block-configs')) { + capturedCallback = callback + } + }) + + renderHook(() => usePipelineConfig()) + + const objectConfigs = { + llm: { model: 'gpt-4' }, + code: { language: 'python' }, + } + + capturedCallback?.(objectConfigs) + + expect(mockSetNodesDefaultConfigs).toHaveBeenCalledWith(objectConfigs) + }) + }) + + describe('handleUpdatePublishedAt', () => { + it('should set published at from workflow response', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseWorkflowConfig.mockImplementation((url: string, callback: (data: unknown) => void) => { + if (url.includes('/publish')) { + capturedCallback = callback + } + }) + + renderHook(() => usePipelineConfig()) + + capturedCallback?.({ created_at: '2024-01-01T00:00:00Z' }) + + expect(mockSetPublishedAt).toHaveBeenCalledWith('2024-01-01T00:00:00Z') + }) + + it('should handle undefined workflow response', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseWorkflowConfig.mockImplementation((url: string, callback: (data: unknown) => void) => { + if (url.includes('/publish')) { + capturedCallback = callback + } + }) + + renderHook(() => usePipelineConfig()) + + capturedCallback?.(undefined) + + expect(mockSetPublishedAt).toHaveBeenCalledWith(undefined) + }) + }) + + describe('handleUpdateDataSourceList', () => { + it('should set data source list', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseDataSourceList.mockImplementation((_enabled: boolean, callback: (data: unknown) => void) => { + capturedCallback = callback + }) + + renderHook(() => usePipelineConfig()) + + const dataSourceList = [ + { declaration: { identity: { icon: '/icon.png' } } }, + ] + + capturedCallback?.(dataSourceList) + + expect(mockSetDataSourceList).toHaveBeenCalled() + }) + + it('should prepend basePath to icon if not included', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseDataSourceList.mockImplementation((_enabled: boolean, callback: (data: unknown) => void) => { + capturedCallback = callback + }) + + renderHook(() => usePipelineConfig()) + + const dataSourceList = [ + { declaration: { identity: { icon: '/icon.png' } } }, + ] + + capturedCallback?.(dataSourceList) + + // The callback modifies the array in place + expect(dataSourceList[0].declaration.identity.icon).toBe('/base/icon.png') + }) + + it('should not modify icon if it already includes basePath', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseDataSourceList.mockImplementation((_enabled: boolean, callback: (data: unknown) => void) => { + capturedCallback = callback + }) + + renderHook(() => usePipelineConfig()) + + const dataSourceList = [ + { declaration: { identity: { icon: '/base/icon.png' } } }, + ] + + capturedCallback?.(dataSourceList) + + expect(dataSourceList[0].declaration.identity.icon).toBe('/base/icon.png') + }) + + it('should handle non-string icon', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseDataSourceList.mockImplementation((_enabled: boolean, callback: (data: unknown) => void) => { + capturedCallback = callback + }) + + renderHook(() => usePipelineConfig()) + + const dataSourceList = [ + { declaration: { identity: { icon: { url: '/icon.png' } } } }, + ] + + capturedCallback?.(dataSourceList) + + // Should not modify object icon + expect(dataSourceList[0].declaration.identity.icon).toEqual({ url: '/icon.png' }) + }) + }) + + describe('handleUpdateWorkflowFileUploadConfig', () => { + it('should set file upload config', () => { + let capturedCallback: ((data: unknown) => void) | undefined + mockUseWorkflowConfig.mockImplementation((url: string, callback: (data: unknown) => void) => { + if (url === '/files/upload') { + capturedCallback = callback + } + }) + + renderHook(() => usePipelineConfig()) + + const config = { max_file_size: 10 * 1024 * 1024 } + capturedCallback?.(config) + + expect(mockSetFileUploadConfig).toHaveBeenCalledWith(config) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-pipeline-init.spec.ts b/web/app/components/rag-pipeline/hooks/use-pipeline-init.spec.ts new file mode 100644 index 0000000000..3938525311 --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-pipeline-init.spec.ts @@ -0,0 +1,345 @@ +import { renderHook, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { usePipelineInit } from './use-pipeline-init' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock workflow store +const mockWorkflowStoreGetState = vi.fn() +const mockWorkflowStoreSetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + setState: mockWorkflowStoreSetState, + }), +})) + +// Mock dataset detail context +const mockUseDatasetDetailContextWithSelector = vi.fn() +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: Record) => unknown) => + mockUseDatasetDetailContextWithSelector(selector), +})) + +// Mock workflow service +const mockFetchWorkflowDraft = vi.fn() +const mockSyncWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (url: string) => mockFetchWorkflowDraft(url), + syncWorkflowDraft: (params: unknown) => mockSyncWorkflowDraft(params), +})) + +// Mock usePipelineConfig +vi.mock('./use-pipeline-config', () => ({ + usePipelineConfig: vi.fn(), +})) + +// Mock usePipelineTemplate +vi.mock('./use-pipeline-template', () => ({ + usePipelineTemplate: () => ({ + nodes: [{ id: 'template-node' }], + edges: [], + }), +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('usePipelineInit', () => { + const mockSetEnvSecrets = vi.fn() + const mockSetEnvironmentVariables = vi.fn() + const mockSetSyncWorkflowDraftHash = vi.fn() + const mockSetDraftUpdatedAt = vi.fn() + const mockSetToolPublished = vi.fn() + const mockSetRagPipelineVariables = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + mockWorkflowStoreGetState.mockReturnValue({ + setEnvSecrets: mockSetEnvSecrets, + setEnvironmentVariables: mockSetEnvironmentVariables, + setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash, + setDraftUpdatedAt: mockSetDraftUpdatedAt, + setToolPublished: mockSetToolPublished, + setRagPipelineVariables: mockSetRagPipelineVariables, + }) + + mockUseDatasetDetailContextWithSelector.mockImplementation((selector: (state: Record) => unknown) => { + const state = { + dataset: { + pipeline_id: 'test-pipeline-id', + name: 'Test Knowledge', + icon_info: { icon: 'test-icon' }, + }, + } + return selector(state) + }) + + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [{ id: 'node-1' }], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: true, + environment_variables: [], + rag_pipeline_variables: [], + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return data and isLoading', async () => { + const { result } = renderHook(() => usePipelineInit()) + + expect(result.current.isLoading).toBe(true) + expect(result.current.data).toBeUndefined() + }) + + it('should set pipelineId in workflow store on mount', () => { + renderHook(() => usePipelineInit()) + + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ + pipelineId: 'test-pipeline-id', + knowledgeName: 'Test Knowledge', + knowledgeIcon: { icon: 'test-icon' }, + }) + }) + }) + + describe('data fetching', () => { + it('should fetch workflow draft on mount', async () => { + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft') + }) + }) + + it('should set data after successful fetch', async () => { + const { result } = renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(result.current.data).toBeDefined() + }) + }) + + it('should set isLoading to false after fetch', async () => { + const { result } = renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(result.current.isLoading).toBe(false) + }) + }) + + it('should set draft updated at', async () => { + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetDraftUpdatedAt).toHaveBeenCalledWith('2024-01-01T00:00:00Z') + }) + }) + + it('should set tool published status', async () => { + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetToolPublished).toHaveBeenCalledWith(true) + }) + }) + + it('should set sync hash', async () => { + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('test-hash') + }) + }) + }) + + describe('environment variables handling', () => { + it('should extract secret environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: false, + environment_variables: [ + { id: 'env-1', value_type: 'secret', value: 'secret-value' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ], + rag_pipeline_variables: [], + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetEnvSecrets).toHaveBeenCalledWith({ 'env-1': 'secret-value' }) + }) + }) + + it('should mask secret values in environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: false, + environment_variables: [ + { id: 'env-1', value_type: 'secret', value: 'secret-value' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ], + rag_pipeline_variables: [], + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([ + { id: 'env-1', value_type: 'secret', value: '[__HIDDEN__]' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ]) + }) + }) + + it('should handle empty environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: false, + environment_variables: [], + rag_pipeline_variables: [], + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetEnvSecrets).toHaveBeenCalledWith({}) + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([]) + }) + }) + }) + + describe('rag pipeline variables handling', () => { + it('should set rag pipeline variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: false, + environment_variables: [], + rag_pipeline_variables: [ + { variable: 'query', type: 'text-input' }, + ], + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([ + { variable: 'query', type: 'text-input' }, + ]) + }) + }) + + it('should handle undefined rag pipeline variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'test-hash', + updated_at: '2024-01-01T00:00:00Z', + tool_published: false, + environment_variables: [], + rag_pipeline_variables: undefined, + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([]) + }) + }) + }) + + describe('draft not exist error handling', () => { + it('should create initial workflow when draft does not exist', async () => { + const mockJsonError = { + json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_exist' }), + bodyUsed: false, + } + mockFetchWorkflowDraft.mockRejectedValueOnce(mockJsonError) + mockSyncWorkflowDraft.mockResolvedValue({ updated_at: '2024-01-02T00:00:00Z' }) + + // Second fetch succeeds + mockFetchWorkflowDraft.mockResolvedValueOnce({ + graph: { nodes: [], edges: [], viewport: {} }, + hash: 'new-hash', + updated_at: '2024-01-02T00:00:00Z', + tool_published: false, + environment_variables: [], + rag_pipeline_variables: [], + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ + notInitialWorkflow: true, + shouldAutoOpenStartNodeSelector: true, + }) + }) + }) + + it('should sync initial workflow with template nodes', async () => { + const mockJsonError = { + json: vi.fn().mockResolvedValue({ code: 'draft_workflow_not_exist' }), + bodyUsed: false, + } + mockFetchWorkflowDraft.mockRejectedValueOnce(mockJsonError) + mockSyncWorkflowDraft.mockResolvedValue({ updated_at: '2024-01-02T00:00:00Z' }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockSyncWorkflowDraft).toHaveBeenCalledWith({ + url: '/rag/pipelines/test-pipeline-id/workflows/draft', + params: { + graph: { + nodes: [{ id: 'template-node' }], + edges: [], + }, + environment_variables: [], + }, + }) + }) + }) + }) + + describe('missing datasetId', () => { + it('should not fetch when datasetId is missing', async () => { + mockUseDatasetDetailContextWithSelector.mockImplementation((selector: (state: Record) => unknown) => { + const state = { dataset: undefined } + return selector(state) + }) + + renderHook(() => usePipelineInit()) + + await waitFor(() => { + expect(mockFetchWorkflowDraft).toHaveBeenCalled() + }) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-pipeline-refresh-draft.spec.ts b/web/app/components/rag-pipeline/hooks/use-pipeline-refresh-draft.spec.ts new file mode 100644 index 0000000000..efdb18b7d4 --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-pipeline-refresh-draft.spec.ts @@ -0,0 +1,246 @@ +import { renderHook, waitFor } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { usePipelineRefreshDraft } from './use-pipeline-refresh-draft' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock workflow store +const mockWorkflowStoreGetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + }), +})) + +// Mock useWorkflowUpdate +const mockHandleUpdateWorkflowCanvas = vi.fn() +vi.mock('@/app/components/workflow/hooks', () => ({ + useWorkflowUpdate: () => ({ + handleUpdateWorkflowCanvas: mockHandleUpdateWorkflowCanvas, + }), +})) + +// Mock workflow service +const mockFetchWorkflowDraft = vi.fn() +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (url: string) => mockFetchWorkflowDraft(url), +})) + +// Mock utils +vi.mock('../utils', () => ({ + processNodesWithoutDataSource: (nodes: unknown[], viewport: unknown) => ({ + nodes, + viewport, + }), +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('usePipelineRefreshDraft', () => { + const mockSetSyncWorkflowDraftHash = vi.fn() + const mockSetIsSyncingWorkflowDraft = vi.fn() + const mockSetEnvironmentVariables = vi.fn() + const mockSetEnvSecrets = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash, + setIsSyncingWorkflowDraft: mockSetIsSyncingWorkflowDraft, + setEnvironmentVariables: mockSetEnvironmentVariables, + setEnvSecrets: mockSetEnvSecrets, + }) + + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [{ id: 'node-1' }], + edges: [{ id: 'edge-1' }], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'new-hash', + environment_variables: [], + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return handleRefreshWorkflowDraft function', () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + expect(result.current.handleRefreshWorkflowDraft).toBeDefined() + expect(typeof result.current.handleRefreshWorkflowDraft).toBe('function') + }) + }) + + describe('handleRefreshWorkflowDraft', () => { + it('should set syncing state to true at start', async () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + expect(mockSetIsSyncingWorkflowDraft).toHaveBeenCalledWith(true) + }) + + it('should fetch workflow draft with correct URL', async () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft') + }) + + it('should update workflow canvas with response data', async () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalled() + }) + }) + + it('should update sync hash after fetch', async () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash') + }) + }) + + it('should set syncing state to false after completion', async () => { + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetIsSyncingWorkflowDraft).toHaveBeenLastCalledWith(false) + }) + }) + + it('should handle secret environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'new-hash', + environment_variables: [ + { id: 'env-1', value_type: 'secret', value: 'secret-value' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ], + }) + + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetEnvSecrets).toHaveBeenCalledWith({ 'env-1': 'secret-value' }) + }) + }) + + it('should mask secret values in environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'new-hash', + environment_variables: [ + { id: 'env-1', value_type: 'secret', value: 'secret-value' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ], + }) + + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([ + { id: 'env-1', value_type: 'secret', value: '[__HIDDEN__]' }, + { id: 'env-2', value_type: 'string', value: 'plain-value' }, + ]) + }) + }) + + it('should handle empty environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'new-hash', + environment_variables: [], + }) + + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetEnvSecrets).toHaveBeenCalledWith({}) + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([]) + }) + }) + + it('should handle undefined environment variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'new-hash', + environment_variables: undefined, + }) + + const { result } = renderHook(() => usePipelineRefreshDraft()) + + act(() => { + result.current.handleRefreshWorkflowDraft() + }) + + await waitFor(() => { + expect(mockSetEnvSecrets).toHaveBeenCalledWith({}) + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([]) + }) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-pipeline-run.spec.ts b/web/app/components/rag-pipeline/hooks/use-pipeline-run.spec.ts new file mode 100644 index 0000000000..2b21001839 --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-pipeline-run.spec.ts @@ -0,0 +1,825 @@ +/* eslint-disable ts/no-explicit-any */ +import { renderHook } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { WorkflowRunningStatus } from '@/app/components/workflow/types' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { usePipelineRun } from './use-pipeline-run' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock reactflow +const mockStoreGetState = vi.fn() +const mockGetViewport = vi.fn() +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: mockStoreGetState, + }), + useReactFlow: () => ({ + getViewport: mockGetViewport, + }), +})) + +// Mock workflow store +const mockUseStore = vi.fn() +const mockWorkflowStoreGetState = vi.fn() +const mockWorkflowStoreSetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: Record) => unknown) => mockUseStore(selector), + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + setState: mockWorkflowStoreSetState, + }), +})) + +// Mock useNodesSyncDraft +const mockDoSyncWorkflowDraft = vi.fn() +vi.mock('./use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ + doSyncWorkflowDraft: mockDoSyncWorkflowDraft, + }), +})) + +// Mock workflow hooks +vi.mock('@/app/components/workflow/hooks/use-fetch-workflow-inspect-vars', () => ({ + useSetWorkflowVarsWithValue: () => ({ + fetchInspectVars: vi.fn(), + }), +})) + +const mockHandleUpdateWorkflowCanvas = vi.fn() +vi.mock('@/app/components/workflow/hooks/use-workflow-interactions', () => ({ + useWorkflowUpdate: () => ({ + handleUpdateWorkflowCanvas: mockHandleUpdateWorkflowCanvas, + }), +})) + +vi.mock('@/app/components/workflow/hooks/use-workflow-run-event/use-workflow-run-event', () => ({ + useWorkflowRunEvent: () => ({ + handleWorkflowStarted: vi.fn(), + handleWorkflowFinished: vi.fn(), + handleWorkflowFailed: vi.fn(), + handleWorkflowNodeStarted: vi.fn(), + handleWorkflowNodeFinished: vi.fn(), + handleWorkflowNodeIterationStarted: vi.fn(), + handleWorkflowNodeIterationNext: vi.fn(), + handleWorkflowNodeIterationFinished: vi.fn(), + handleWorkflowNodeLoopStarted: vi.fn(), + handleWorkflowNodeLoopNext: vi.fn(), + handleWorkflowNodeLoopFinished: vi.fn(), + handleWorkflowNodeRetry: vi.fn(), + handleWorkflowAgentLog: vi.fn(), + handleWorkflowTextChunk: vi.fn(), + handleWorkflowTextReplace: vi.fn(), + }), +})) + +// Mock service +const mockSsePost = vi.fn() +vi.mock('@/service/base', () => ({ + ssePost: (url: string, ...args: unknown[]) => mockSsePost(url, ...args), +})) + +const mockStopWorkflowRun = vi.fn() +vi.mock('@/service/workflow', () => ({ + stopWorkflowRun: (url: string) => mockStopWorkflowRun(url), +})) + +const mockInvalidAllLastRun = vi.fn() +vi.mock('@/service/use-workflow', () => ({ + useInvalidAllLastRun: () => mockInvalidAllLastRun, +})) + +// Mock FlowType +vi.mock('@/types/common', () => ({ + FlowType: { + ragPipeline: 'rag-pipeline', + }, +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('usePipelineRun', () => { + const mockSetNodes = vi.fn() + const mockGetNodes = vi.fn() + const mockSetBackupDraft = vi.fn() + const mockSetEnvironmentVariables = vi.fn() + const mockSetRagPipelineVariables = vi.fn() + const mockSetWorkflowRunningData = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + // Mock DOM element + const mockWorkflowContainer = document.createElement('div') + mockWorkflowContainer.id = 'workflow-container' + Object.defineProperty(mockWorkflowContainer, 'clientWidth', { value: 1000 }) + Object.defineProperty(mockWorkflowContainer, 'clientHeight', { value: 800 }) + document.body.appendChild(mockWorkflowContainer) + + mockStoreGetState.mockReturnValue({ + getNodes: mockGetNodes, + setNodes: mockSetNodes, + edges: [], + }) + + mockGetNodes.mockReturnValue([ + { id: 'node-1', data: { type: 'start', selected: true, _runningStatus: WorkflowRunningStatus.Running } }, + ]) + + mockGetViewport.mockReturnValue({ x: 0, y: 0, zoom: 1 }) + + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + backupDraft: undefined, + environmentVariables: [], + setBackupDraft: mockSetBackupDraft, + setEnvironmentVariables: mockSetEnvironmentVariables, + setRagPipelineVariables: mockSetRagPipelineVariables, + setWorkflowRunningData: mockSetWorkflowRunningData, + }) + + mockUseStore.mockImplementation((selector: (state: Record) => unknown) => { + return selector({ pipelineId: 'test-pipeline-id' }) + }) + + mockDoSyncWorkflowDraft.mockResolvedValue(undefined) + }) + + afterEach(() => { + const container = document.getElementById('workflow-container') + if (container) { + document.body.removeChild(container) + } + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return handleBackupDraft function', () => { + const { result } = renderHook(() => usePipelineRun()) + + expect(result.current.handleBackupDraft).toBeDefined() + expect(typeof result.current.handleBackupDraft).toBe('function') + }) + + it('should return handleLoadBackupDraft function', () => { + const { result } = renderHook(() => usePipelineRun()) + + expect(result.current.handleLoadBackupDraft).toBeDefined() + expect(typeof result.current.handleLoadBackupDraft).toBe('function') + }) + + it('should return handleRun function', () => { + const { result } = renderHook(() => usePipelineRun()) + + expect(result.current.handleRun).toBeDefined() + expect(typeof result.current.handleRun).toBe('function') + }) + + it('should return handleStopRun function', () => { + const { result } = renderHook(() => usePipelineRun()) + + expect(result.current.handleStopRun).toBeDefined() + expect(typeof result.current.handleStopRun).toBe('function') + }) + + it('should return handleRestoreFromPublishedWorkflow function', () => { + const { result } = renderHook(() => usePipelineRun()) + + expect(result.current.handleRestoreFromPublishedWorkflow).toBeDefined() + expect(typeof result.current.handleRestoreFromPublishedWorkflow).toBe('function') + }) + }) + + describe('handleBackupDraft', () => { + it('should backup draft when no backup exists', () => { + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleBackupDraft() + }) + + expect(mockSetBackupDraft).toHaveBeenCalled() + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should not backup draft when backup already exists', () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + backupDraft: { nodes: [], edges: [], viewport: {}, environmentVariables: [] }, + environmentVariables: [], + setBackupDraft: mockSetBackupDraft, + setEnvironmentVariables: mockSetEnvironmentVariables, + setRagPipelineVariables: mockSetRagPipelineVariables, + setWorkflowRunningData: mockSetWorkflowRunningData, + }) + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleBackupDraft() + }) + + expect(mockSetBackupDraft).not.toHaveBeenCalled() + }) + }) + + describe('handleLoadBackupDraft', () => { + it('should load backup draft when exists', () => { + const backupDraft = { + nodes: [{ id: 'backup-node' }], + edges: [{ id: 'backup-edge' }], + viewport: { x: 100, y: 100, zoom: 1.5 }, + environmentVariables: [{ key: 'ENV', value: 'test' }], + } + + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + backupDraft, + environmentVariables: [], + setBackupDraft: mockSetBackupDraft, + setEnvironmentVariables: mockSetEnvironmentVariables, + setRagPipelineVariables: mockSetRagPipelineVariables, + setWorkflowRunningData: mockSetWorkflowRunningData, + }) + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleLoadBackupDraft() + }) + + expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({ + nodes: backupDraft.nodes, + edges: backupDraft.edges, + viewport: backupDraft.viewport, + }) + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith(backupDraft.environmentVariables) + expect(mockSetBackupDraft).toHaveBeenCalledWith(undefined) + }) + + it('should not load when no backup exists', () => { + mockWorkflowStoreGetState.mockReturnValue({ + pipelineId: 'test-pipeline-id', + backupDraft: undefined, + environmentVariables: [], + setBackupDraft: mockSetBackupDraft, + setEnvironmentVariables: mockSetEnvironmentVariables, + setRagPipelineVariables: mockSetRagPipelineVariables, + setWorkflowRunningData: mockSetWorkflowRunningData, + }) + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleLoadBackupDraft() + }) + + expect(mockHandleUpdateWorkflowCanvas).not.toHaveBeenCalled() + }) + }) + + describe('handleStopRun', () => { + it('should call stop workflow run service', () => { + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleStopRun('task-123') + }) + + expect(mockStopWorkflowRun).toHaveBeenCalledWith( + '/rag/pipelines/test-pipeline-id/workflow-runs/tasks/task-123/stop', + ) + }) + }) + + describe('handleRestoreFromPublishedWorkflow', () => { + it('should restore workflow from published version', () => { + const publishedWorkflow = { + graph: { + nodes: [{ id: 'pub-node', data: { type: 'start' } }], + edges: [{ id: 'pub-edge' }], + viewport: { x: 50, y: 50, zoom: 1 }, + }, + environment_variables: [{ key: 'PUB_ENV', value: 'pub' }], + rag_pipeline_variables: [{ variable: 'input', type: 'text-input' }], + } + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleRestoreFromPublishedWorkflow(publishedWorkflow as any) + }) + + expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({ + nodes: [{ id: 'pub-node', data: { type: 'start', selected: false }, selected: false }], + edges: publishedWorkflow.graph.edges, + viewport: publishedWorkflow.graph.viewport, + }) + }) + + it('should set environment variables from published workflow', () => { + const publishedWorkflow = { + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + environment_variables: [{ key: 'ENV', value: 'value' }], + rag_pipeline_variables: [], + } + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleRestoreFromPublishedWorkflow(publishedWorkflow as any) + }) + + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([{ key: 'ENV', value: 'value' }]) + }) + + it('should set rag pipeline variables from published workflow', () => { + const publishedWorkflow = { + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + environment_variables: [], + rag_pipeline_variables: [{ variable: 'query', type: 'text-input' }], + } + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleRestoreFromPublishedWorkflow(publishedWorkflow as any) + }) + + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([{ variable: 'query', type: 'text-input' }]) + }) + + it('should handle empty environment and rag pipeline variables', () => { + const publishedWorkflow = { + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + environment_variables: undefined, + rag_pipeline_variables: undefined, + } + + const { result } = renderHook(() => usePipelineRun()) + + act(() => { + result.current.handleRestoreFromPublishedWorkflow(publishedWorkflow as any) + }) + + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([]) + expect(mockSetRagPipelineVariables).toHaveBeenCalledWith([]) + }) + }) + + describe('handleRun', () => { + it('should sync workflow draft before running', async () => { + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('should reset node selection and running status', async () => { + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + expect(mockSetNodes).toHaveBeenCalled() + }) + + it('should clear history workflow data', async () => { + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ historyWorkflowData: undefined }) + }) + + it('should set initial running data', async () => { + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + expect(mockSetWorkflowRunningData).toHaveBeenCalledWith({ + result: { + inputs_truncated: false, + process_data_truncated: false, + outputs_truncated: false, + status: WorkflowRunningStatus.Running, + }, + tracing: [], + resultText: '', + }) + }) + + it('should call ssePost with correct URL', async () => { + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: { query: 'test' } }) + }) + + expect(mockSsePost).toHaveBeenCalledWith( + '/rag/pipelines/test-pipeline-id/workflows/draft/run', + expect.any(Object), + expect.any(Object), + ) + }) + + it('should call onWorkflowStarted callback when provided', async () => { + const onWorkflowStarted = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onWorkflowStarted }) + }) + + // Trigger the callback + await act(async () => { + capturedCallbacks.onWorkflowStarted?.({ task_id: 'task-1' }) + }) + + expect(onWorkflowStarted).toHaveBeenCalledWith({ task_id: 'task-1' }) + }) + + it('should call onWorkflowFinished callback when provided', async () => { + const onWorkflowFinished = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onWorkflowFinished }) + }) + + await act(async () => { + capturedCallbacks.onWorkflowFinished?.({ status: 'succeeded' }) + }) + + expect(onWorkflowFinished).toHaveBeenCalledWith({ status: 'succeeded' }) + }) + + it('should call onError callback when provided', async () => { + const onError = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onError }) + }) + + await act(async () => { + capturedCallbacks.onError?.({ message: 'error' }) + }) + + expect(onError).toHaveBeenCalledWith({ message: 'error' }) + }) + + it('should call onNodeStarted callback when provided', async () => { + const onNodeStarted = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onNodeStarted }) + }) + + await act(async () => { + capturedCallbacks.onNodeStarted?.({ node_id: 'node-1' }) + }) + + expect(onNodeStarted).toHaveBeenCalledWith({ node_id: 'node-1' }) + }) + + it('should call onNodeFinished callback when provided', async () => { + const onNodeFinished = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onNodeFinished }) + }) + + await act(async () => { + capturedCallbacks.onNodeFinished?.({ node_id: 'node-1' }) + }) + + expect(onNodeFinished).toHaveBeenCalledWith({ node_id: 'node-1' }) + }) + + it('should call onIterationStart callback when provided', async () => { + const onIterationStart = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onIterationStart }) + }) + + await act(async () => { + capturedCallbacks.onIterationStart?.({ iteration_id: 'iter-1' }) + }) + + expect(onIterationStart).toHaveBeenCalledWith({ iteration_id: 'iter-1' }) + }) + + it('should call onIterationNext callback when provided', async () => { + const onIterationNext = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onIterationNext }) + }) + + await act(async () => { + capturedCallbacks.onIterationNext?.({ index: 1 }) + }) + + expect(onIterationNext).toHaveBeenCalledWith({ index: 1 }) + }) + + it('should call onIterationFinish callback when provided', async () => { + const onIterationFinish = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onIterationFinish }) + }) + + await act(async () => { + capturedCallbacks.onIterationFinish?.({ iteration_id: 'iter-1' }) + }) + + expect(onIterationFinish).toHaveBeenCalledWith({ iteration_id: 'iter-1' }) + }) + + it('should call onLoopStart callback when provided', async () => { + const onLoopStart = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onLoopStart }) + }) + + await act(async () => { + capturedCallbacks.onLoopStart?.({ loop_id: 'loop-1' }) + }) + + expect(onLoopStart).toHaveBeenCalledWith({ loop_id: 'loop-1' }) + }) + + it('should call onLoopNext callback when provided', async () => { + const onLoopNext = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onLoopNext }) + }) + + await act(async () => { + capturedCallbacks.onLoopNext?.({ index: 2 }) + }) + + expect(onLoopNext).toHaveBeenCalledWith({ index: 2 }) + }) + + it('should call onLoopFinish callback when provided', async () => { + const onLoopFinish = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onLoopFinish }) + }) + + await act(async () => { + capturedCallbacks.onLoopFinish?.({ loop_id: 'loop-1' }) + }) + + expect(onLoopFinish).toHaveBeenCalledWith({ loop_id: 'loop-1' }) + }) + + it('should call onNodeRetry callback when provided', async () => { + const onNodeRetry = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onNodeRetry }) + }) + + await act(async () => { + capturedCallbacks.onNodeRetry?.({ node_id: 'node-1', retry: 1 }) + }) + + expect(onNodeRetry).toHaveBeenCalledWith({ node_id: 'node-1', retry: 1 }) + }) + + it('should call onAgentLog callback when provided', async () => { + const onAgentLog = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onAgentLog }) + }) + + await act(async () => { + capturedCallbacks.onAgentLog?.({ message: 'agent log' }) + }) + + expect(onAgentLog).toHaveBeenCalledWith({ message: 'agent log' }) + }) + + it('should handle onTextChunk callback', async () => { + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + await act(async () => { + capturedCallbacks.onTextChunk?.({ text: 'chunk' }) + }) + + // Just verify it doesn't throw + expect(capturedCallbacks.onTextChunk).toBeDefined() + }) + + it('should handle onTextReplace callback', async () => { + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + await act(async () => { + capturedCallbacks.onTextReplace?.({ text: 'replaced' }) + }) + + // Just verify it doesn't throw + expect(capturedCallbacks.onTextReplace).toBeDefined() + }) + + it('should pass rest callback to ssePost', async () => { + const customCallback = vi.fn() + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + await act(async () => { + await result.current.handleRun({ inputs: {} }, { onData: customCallback } as any) + }) + + expect(capturedCallbacks.onData).toBeDefined() + }) + + it('should handle callbacks without optional handlers', async () => { + let capturedCallbacks: Record void> = {} + + mockSsePost.mockImplementation((_url, _body, callbacks) => { + capturedCallbacks = callbacks + }) + + const { result } = renderHook(() => usePipelineRun()) + + // Run without any optional callbacks + await act(async () => { + await result.current.handleRun({ inputs: {} }) + }) + + // Trigger all callbacks - they should not throw even without optional handlers + await act(async () => { + capturedCallbacks.onWorkflowStarted?.({ task_id: 'task-1' }) + capturedCallbacks.onWorkflowFinished?.({ status: 'succeeded' }) + capturedCallbacks.onError?.({ message: 'error' }) + capturedCallbacks.onNodeStarted?.({ node_id: 'node-1' }) + capturedCallbacks.onNodeFinished?.({ node_id: 'node-1' }) + capturedCallbacks.onIterationStart?.({ iteration_id: 'iter-1' }) + capturedCallbacks.onIterationNext?.({ index: 1 }) + capturedCallbacks.onIterationFinish?.({ iteration_id: 'iter-1' }) + capturedCallbacks.onLoopStart?.({ loop_id: 'loop-1' }) + capturedCallbacks.onLoopNext?.({ index: 2 }) + capturedCallbacks.onLoopFinish?.({ loop_id: 'loop-1' }) + capturedCallbacks.onNodeRetry?.({ node_id: 'node-1', retry: 1 }) + capturedCallbacks.onAgentLog?.({ message: 'agent log' }) + capturedCallbacks.onTextChunk?.({ text: 'chunk' }) + capturedCallbacks.onTextReplace?.({ text: 'replaced' }) + }) + + // Verify ssePost was called + expect(mockSsePost).toHaveBeenCalled() + }) + }) +}) diff --git a/web/app/components/rag-pipeline/hooks/use-pipeline-start-run.spec.ts b/web/app/components/rag-pipeline/hooks/use-pipeline-start-run.spec.ts new file mode 100644 index 0000000000..4266fb993d --- /dev/null +++ b/web/app/components/rag-pipeline/hooks/use-pipeline-start-run.spec.ts @@ -0,0 +1,217 @@ +import { renderHook } from '@testing-library/react' +import { act } from 'react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { WorkflowRunningStatus } from '@/app/components/workflow/types' + +// ============================================================================ +// Import after mocks +// ============================================================================ + +import { usePipelineStartRun } from './use-pipeline-start-run' + +// ============================================================================ +// Mocks +// ============================================================================ + +// Mock workflow store +const mockWorkflowStoreGetState = vi.fn() +const mockWorkflowStoreSetState = vi.fn() +vi.mock('@/app/components/workflow/store', () => ({ + useWorkflowStore: () => ({ + getState: mockWorkflowStoreGetState, + setState: mockWorkflowStoreSetState, + }), +})) + +// Mock workflow interactions +const mockHandleCancelDebugAndPreviewPanel = vi.fn() +vi.mock('@/app/components/workflow/hooks', () => ({ + useWorkflowInteractions: () => ({ + handleCancelDebugAndPreviewPanel: mockHandleCancelDebugAndPreviewPanel, + }), +})) + +// Mock useNodesSyncDraft +const mockDoSyncWorkflowDraft = vi.fn() +vi.mock('@/app/components/rag-pipeline/hooks', () => ({ + useNodesSyncDraft: () => ({ + doSyncWorkflowDraft: mockDoSyncWorkflowDraft, + }), + useInputFieldPanel: () => ({ + closeAllInputFieldPanels: vi.fn(), + }), +})) + +// ============================================================================ +// Tests +// ============================================================================ + +describe('usePipelineStartRun', () => { + const mockSetIsPreparingDataSource = vi.fn() + const mockSetShowEnvPanel = vi.fn() + const mockSetShowDebugAndPreviewPanel = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: undefined, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + mockDoSyncWorkflowDraft.mockResolvedValue(undefined) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('hook initialization', () => { + it('should return handleStartWorkflowRun function', () => { + const { result } = renderHook(() => usePipelineStartRun()) + + expect(result.current.handleStartWorkflowRun).toBeDefined() + expect(typeof result.current.handleStartWorkflowRun).toBe('function') + }) + + it('should return handleWorkflowStartRunInWorkflow function', () => { + const { result } = renderHook(() => usePipelineStartRun()) + + expect(result.current.handleWorkflowStartRunInWorkflow).toBeDefined() + expect(typeof result.current.handleWorkflowStartRunInWorkflow).toBe('function') + }) + }) + + describe('handleWorkflowStartRunInWorkflow', () => { + it('should not proceed when workflow is already running', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: { + result: { status: WorkflowRunningStatus.Running }, + }, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + await result.current.handleWorkflowStartRunInWorkflow() + }) + + expect(mockSetShowEnvPanel).not.toHaveBeenCalled() + }) + + it('should set preparing data source when not preparing and has running data', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: { + result: { status: WorkflowRunningStatus.Succeeded }, + }, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + await result.current.handleWorkflowStartRunInWorkflow() + }) + + expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({ + isPreparingDataSource: true, + workflowRunningData: undefined, + }) + }) + + it('should cancel debug panel when already showing', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: undefined, + isPreparingDataSource: false, + showDebugAndPreviewPanel: true, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + await result.current.handleWorkflowStartRunInWorkflow() + }) + + expect(mockSetIsPreparingDataSource).toHaveBeenCalledWith(false) + expect(mockHandleCancelDebugAndPreviewPanel).toHaveBeenCalled() + }) + + it('should sync draft and show debug panel when conditions are met', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: undefined, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + await result.current.handleWorkflowStartRunInWorkflow() + }) + + expect(mockDoSyncWorkflowDraft).toHaveBeenCalled() + expect(mockSetIsPreparingDataSource).toHaveBeenCalledWith(true) + expect(mockSetShowDebugAndPreviewPanel).toHaveBeenCalledWith(true) + }) + + it('should hide env panel at start', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: undefined, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + await result.current.handleWorkflowStartRunInWorkflow() + }) + + expect(mockSetShowEnvPanel).toHaveBeenCalledWith(false) + }) + }) + + describe('handleStartWorkflowRun', () => { + it('should call handleWorkflowStartRunInWorkflow', async () => { + mockWorkflowStoreGetState.mockReturnValue({ + workflowRunningData: undefined, + isPreparingDataSource: false, + showDebugAndPreviewPanel: false, + setIsPreparingDataSource: mockSetIsPreparingDataSource, + setShowEnvPanel: mockSetShowEnvPanel, + setShowDebugAndPreviewPanel: mockSetShowDebugAndPreviewPanel, + }) + + const { result } = renderHook(() => usePipelineStartRun()) + + await act(async () => { + result.current.handleStartWorkflowRun() + }) + + // Should trigger the same workflow as handleWorkflowStartRunInWorkflow + expect(mockSetShowEnvPanel).toHaveBeenCalledWith(false) + }) + }) +}) diff --git a/web/app/components/rag-pipeline/store/index.spec.ts b/web/app/components/rag-pipeline/store/index.spec.ts new file mode 100644 index 0000000000..c8c0a35330 --- /dev/null +++ b/web/app/components/rag-pipeline/store/index.spec.ts @@ -0,0 +1,289 @@ +/* eslint-disable ts/no-explicit-any */ +import type { DataSourceItem } from '@/app/components/workflow/block-selector/types' +import { describe, expect, it, vi } from 'vitest' +import { createRagPipelineSliceSlice } from './index' + +// Mock the transformDataSourceToTool function +vi.mock('@/app/components/workflow/block-selector/utils', () => ({ + transformDataSourceToTool: (item: DataSourceItem) => ({ + ...item, + transformed: true, + }), +})) + +describe('createRagPipelineSliceSlice', () => { + const mockSet = vi.fn() + + describe('initial state', () => { + it('should have empty pipelineId', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.pipelineId).toBe('') + }) + + it('should have empty knowledgeName', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.knowledgeName).toBe('') + }) + + it('should have showInputFieldPanel as false', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.showInputFieldPanel).toBe(false) + }) + + it('should have showInputFieldPreviewPanel as false', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.showInputFieldPreviewPanel).toBe(false) + }) + + it('should have inputFieldEditPanelProps as null', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.inputFieldEditPanelProps).toBeNull() + }) + + it('should have empty nodesDefaultConfigs', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.nodesDefaultConfigs).toEqual({}) + }) + + it('should have empty ragPipelineVariables', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.ragPipelineVariables).toEqual([]) + }) + + it('should have empty dataSourceList', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.dataSourceList).toEqual([]) + }) + + it('should have isPreparingDataSource as false', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + expect(slice.isPreparingDataSource).toBe(false) + }) + }) + + describe('setShowInputFieldPanel', () => { + it('should call set with showInputFieldPanel true', () => { + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setShowInputFieldPanel(true) + + expect(mockSet).toHaveBeenCalledWith(expect.any(Function)) + + // Get the setter function and execute it + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ showInputFieldPanel: true }) + }) + + it('should call set with showInputFieldPanel false', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setShowInputFieldPanel(false) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ showInputFieldPanel: false }) + }) + }) + + describe('setShowInputFieldPreviewPanel', () => { + it('should call set with showInputFieldPreviewPanel true', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setShowInputFieldPreviewPanel(true) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ showInputFieldPreviewPanel: true }) + }) + + it('should call set with showInputFieldPreviewPanel false', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setShowInputFieldPreviewPanel(false) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ showInputFieldPreviewPanel: false }) + }) + }) + + describe('setInputFieldEditPanelProps', () => { + it('should call set with inputFieldEditPanelProps object', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + const props = { type: 'create' as const } + + slice.setInputFieldEditPanelProps(props as any) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ inputFieldEditPanelProps: props }) + }) + + it('should call set with inputFieldEditPanelProps null', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setInputFieldEditPanelProps(null) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ inputFieldEditPanelProps: null }) + }) + }) + + describe('setNodesDefaultConfigs', () => { + it('should call set with nodesDefaultConfigs', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + const configs = { node1: { key: 'value' } } + + slice.setNodesDefaultConfigs(configs) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ nodesDefaultConfigs: configs }) + }) + + it('should call set with empty nodesDefaultConfigs', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setNodesDefaultConfigs({}) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ nodesDefaultConfigs: {} }) + }) + }) + + describe('setRagPipelineVariables', () => { + it('should call set with ragPipelineVariables', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + const variables = [ + { type: 'text-input', variable: 'var1', label: 'Var 1', required: true }, + ] + + slice.setRagPipelineVariables(variables as any) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ ragPipelineVariables: variables }) + }) + + it('should call set with empty ragPipelineVariables', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setRagPipelineVariables([]) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ ragPipelineVariables: [] }) + }) + }) + + describe('setDataSourceList', () => { + it('should transform and set dataSourceList', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + const dataSourceList: DataSourceItem[] = [ + { name: 'source1', key: 'key1' } as unknown as DataSourceItem, + { name: 'source2', key: 'key2' } as unknown as DataSourceItem, + ] + + slice.setDataSourceList(dataSourceList) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result.dataSourceList).toHaveLength(2) + expect(result.dataSourceList[0]).toEqual({ name: 'source1', key: 'key1', transformed: true }) + expect(result.dataSourceList[1]).toEqual({ name: 'source2', key: 'key2', transformed: true }) + }) + + it('should set empty dataSourceList', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setDataSourceList([]) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result.dataSourceList).toEqual([]) + }) + }) + + describe('setIsPreparingDataSource', () => { + it('should call set with isPreparingDataSource true', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setIsPreparingDataSource(true) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ isPreparingDataSource: true }) + }) + + it('should call set with isPreparingDataSource false', () => { + mockSet.mockClear() + const slice = createRagPipelineSliceSlice(mockSet, vi.fn() as any, vi.fn() as any) + + slice.setIsPreparingDataSource(false) + + const setterFn = mockSet.mock.calls[0][0] + const result = setterFn() + expect(result).toEqual({ isPreparingDataSource: false }) + }) + }) +}) + +describe('RagPipelineSliceShape type', () => { + it('should define all required properties', () => { + const slice = createRagPipelineSliceSlice(vi.fn(), vi.fn() as any, vi.fn() as any) + + // Check all properties exist + expect(slice).toHaveProperty('pipelineId') + expect(slice).toHaveProperty('knowledgeName') + expect(slice).toHaveProperty('showInputFieldPanel') + expect(slice).toHaveProperty('setShowInputFieldPanel') + expect(slice).toHaveProperty('showInputFieldPreviewPanel') + expect(slice).toHaveProperty('setShowInputFieldPreviewPanel') + expect(slice).toHaveProperty('inputFieldEditPanelProps') + expect(slice).toHaveProperty('setInputFieldEditPanelProps') + expect(slice).toHaveProperty('nodesDefaultConfigs') + expect(slice).toHaveProperty('setNodesDefaultConfigs') + expect(slice).toHaveProperty('ragPipelineVariables') + expect(slice).toHaveProperty('setRagPipelineVariables') + expect(slice).toHaveProperty('dataSourceList') + expect(slice).toHaveProperty('setDataSourceList') + expect(slice).toHaveProperty('isPreparingDataSource') + expect(slice).toHaveProperty('setIsPreparingDataSource') + }) + + it('should have all setters as functions', () => { + const slice = createRagPipelineSliceSlice(vi.fn(), vi.fn() as any, vi.fn() as any) + + expect(typeof slice.setShowInputFieldPanel).toBe('function') + expect(typeof slice.setShowInputFieldPreviewPanel).toBe('function') + expect(typeof slice.setInputFieldEditPanelProps).toBe('function') + expect(typeof slice.setNodesDefaultConfigs).toBe('function') + expect(typeof slice.setRagPipelineVariables).toBe('function') + expect(typeof slice.setDataSourceList).toBe('function') + expect(typeof slice.setIsPreparingDataSource).toBe('function') + }) +}) diff --git a/web/app/components/rag-pipeline/utils/index.spec.ts b/web/app/components/rag-pipeline/utils/index.spec.ts new file mode 100644 index 0000000000..9d816af685 --- /dev/null +++ b/web/app/components/rag-pipeline/utils/index.spec.ts @@ -0,0 +1,348 @@ +import type { Viewport } from 'reactflow' +import type { Node } from '@/app/components/workflow/types' +import { describe, expect, it, vi } from 'vitest' +import { BlockEnum } from '@/app/components/workflow/types' +import { processNodesWithoutDataSource } from './nodes' + +// Mock constants +vi.mock('@/app/components/workflow/constants', () => ({ + CUSTOM_NODE: 'custom', + NODE_WIDTH_X_OFFSET: 400, + START_INITIAL_POSITION: { x: 100, y: 100 }, +})) + +vi.mock('@/app/components/workflow/nodes/data-source-empty/constants', () => ({ + CUSTOM_DATA_SOURCE_EMPTY_NODE: 'data-source-empty', +})) + +vi.mock('@/app/components/workflow/note-node/constants', () => ({ + CUSTOM_NOTE_NODE: 'note', +})) + +vi.mock('@/app/components/workflow/note-node/types', () => ({ + NoteTheme: { blue: 'blue' }, +})) + +vi.mock('@/app/components/workflow/utils', () => ({ + generateNewNode: ({ id, type, data, position }: { id: string, type?: string, data: object, position: { x: number, y: number } }) => ({ + newNode: { id, type: type || 'custom', data, position }, + }), +})) + +describe('processNodesWithoutDataSource', () => { + describe('when nodes contain DataSource', () => { + it('should return original nodes and viewport unchanged', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.DataSource, title: 'Data Source' }, + position: { x: 100, y: 100 }, + } as Node, + { + id: 'node-2', + type: 'custom', + data: { type: BlockEnum.End, title: 'End' }, + position: { x: 500, y: 100 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + expect(result.nodes).toBe(nodes) + expect(result.viewport).toBe(viewport) + }) + + it('should check all nodes before returning early', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.Start, title: 'Start' }, + position: { x: 0, y: 0 }, + } as Node, + { + id: 'node-2', + type: 'custom', + data: { type: BlockEnum.DataSource, title: 'Data Source' }, + position: { x: 100, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + expect(result.nodes).toBe(nodes) + }) + }) + + describe('when nodes do not contain DataSource', () => { + it('should add data source empty node and note node for single custom node', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'Knowledge Base' }, + position: { x: 500, y: 200 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + expect(result.nodes.length).toBe(3) + expect(result.nodes[0].id).toBe('data-source-empty') + expect(result.nodes[1].id).toBe('note') + expect(result.nodes[2]).toBe(nodes[0]) + }) + + it('should use the leftmost custom node position for new nodes', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB 1' }, + position: { x: 700, y: 100 }, + } as Node, + { + id: 'node-2', + type: 'custom', + data: { type: BlockEnum.End, title: 'End' }, + position: { x: 200, y: 100 }, // This is the leftmost + } as Node, + { + id: 'node-3', + type: 'custom', + data: { type: BlockEnum.Start, title: 'Start' }, + position: { x: 500, y: 100 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + // New nodes should be positioned based on the leftmost node (x: 200) + // startX = 200 - 400 = -200 + expect(result.nodes[0].position.x).toBe(-200) + expect(result.nodes[0].position.y).toBe(100) + }) + + it('should adjust viewport based on new node position', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 300, y: 200 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + // startX = 300 - 400 = -100 + // startY = 200 + // viewport.x = (100 - (-100)) * 1 = 200 + // viewport.y = (100 - 200) * 1 = -100 + expect(result.viewport).toEqual({ + x: 200, + y: -100, + zoom: 1, + }) + }) + + it('should apply zoom factor to viewport calculation', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 300, y: 200 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 2 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + // startX = 300 - 400 = -100 + // startY = 200 + // viewport.x = (100 - (-100)) * 2 = 400 + // viewport.y = (100 - 200) * 2 = -200 + expect(result.viewport).toEqual({ + x: 400, + y: -200, + zoom: 2, + }) + }) + + it('should use default zoom 1 when viewport zoom is undefined', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 500, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes, undefined) + + expect(result.viewport?.zoom).toBe(1) + }) + + it('should add note node below data source empty node', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 500, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + // Data source empty node position + const dataSourceEmptyNode = result.nodes[0] + const noteNode = result.nodes[1] + + // Note node should be 100px below data source empty node + expect(noteNode.position.x).toBe(dataSourceEmptyNode.position.x) + expect(noteNode.position.y).toBe(dataSourceEmptyNode.position.y + 100) + }) + + it('should set correct data for data source empty node', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 500, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + expect(result.nodes[0].data.type).toBe(BlockEnum.DataSourceEmpty) + expect(result.nodes[0].data._isTempNode).toBe(true) + expect(result.nodes[0].data.width).toBe(240) + }) + + it('should set correct data for note node', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 500, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + const noteNode = result.nodes[1] + const noteData = noteNode.data as Record + expect(noteData._isTempNode).toBe(true) + expect(noteData.theme).toBe('blue') + expect(noteData.width).toBe(240) + expect(noteData.height).toBe(300) + expect(noteData.showAuthor).toBe(true) + }) + }) + + describe('when nodes array is empty', () => { + it('should return empty nodes array unchanged', () => { + const nodes: Node[] = [] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + expect(result.nodes).toEqual([]) + expect(result.viewport).toBe(viewport) + }) + }) + + describe('when no custom nodes exist', () => { + it('should return original nodes when only non-custom nodes', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'special', // Not 'custom' + data: { type: BlockEnum.Start, title: 'Start' }, + position: { x: 100, y: 100 }, + } as Node, + ] + const viewport: Viewport = { x: 0, y: 0, zoom: 1 } + + const result = processNodesWithoutDataSource(nodes, viewport) + + // No custom nodes to find leftmost, so no new nodes are added + expect(result.nodes).toBe(nodes) + expect(result.viewport).toBe(viewport) + }) + }) + + describe('edge cases', () => { + it('should handle nodes with same x position', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB 1' }, + position: { x: 300, y: 100 }, + } as Node, + { + id: 'node-2', + type: 'custom', + data: { type: BlockEnum.End, title: 'End' }, + position: { x: 300, y: 200 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + // First node should be used as leftNode + expect(result.nodes.length).toBe(4) + }) + + it('should handle negative positions', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: -100, y: -50 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes) + + // startX = -100 - 400 = -500 + expect(result.nodes[0].position.x).toBe(-500) + expect(result.nodes[0].position.y).toBe(-50) + }) + + it('should handle undefined viewport gracefully', () => { + const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + data: { type: BlockEnum.KnowledgeBase, title: 'KB' }, + position: { x: 500, y: 100 }, + } as Node, + ] + + const result = processNodesWithoutDataSource(nodes, undefined) + + expect(result.viewport).toBeDefined() + expect(result.viewport?.zoom).toBe(1) + }) + }) +}) + +describe('module exports', () => { + it('should export processNodesWithoutDataSource', () => { + expect(processNodesWithoutDataSource).toBeDefined() + expect(typeof processNodesWithoutDataSource).toBe('function') + }) +}) diff --git a/web/app/components/share/text-generation/info-modal.spec.tsx b/web/app/components/share/text-generation/info-modal.spec.tsx new file mode 100644 index 0000000000..025c5edde1 --- /dev/null +++ b/web/app/components/share/text-generation/info-modal.spec.tsx @@ -0,0 +1,205 @@ +import type { SiteInfo } from '@/models/share' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import InfoModal from './info-modal' + +// Only mock react-i18next for translations +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +afterEach(() => { + cleanup() +}) + +describe('InfoModal', () => { + const mockOnClose = vi.fn() + + const baseSiteInfo: SiteInfo = { + title: 'Test App', + icon: '๐Ÿš€', + icon_type: 'emoji', + icon_background: '#ffffff', + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should not render when isShow is false', () => { + render( + , + ) + + expect(screen.queryByText('Test App')).not.toBeInTheDocument() + }) + + it('should render when isShow is true', () => { + render( + , + ) + + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + + it('should render app title', () => { + render( + , + ) + + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + + it('should render copyright when provided', () => { + const siteInfoWithCopyright: SiteInfo = { + ...baseSiteInfo, + copyright: 'Dify Inc.', + } + + render( + , + ) + + expect(screen.getByText(/Dify Inc./)).toBeInTheDocument() + }) + + it('should render current year in copyright', () => { + const siteInfoWithCopyright: SiteInfo = { + ...baseSiteInfo, + copyright: 'Test Company', + } + + render( + , + ) + + const currentYear = new Date().getFullYear().toString() + expect(screen.getByText(new RegExp(currentYear))).toBeInTheDocument() + }) + + it('should render custom disclaimer when provided', () => { + const siteInfoWithDisclaimer: SiteInfo = { + ...baseSiteInfo, + custom_disclaimer: 'This is a custom disclaimer', + } + + render( + , + ) + + expect(screen.getByText('This is a custom disclaimer')).toBeInTheDocument() + }) + + it('should not render copyright section when not provided', () => { + render( + , + ) + + const year = new Date().getFullYear().toString() + expect(screen.queryByText(new RegExp(`ยฉ.*${year}`))).not.toBeInTheDocument() + }) + + it('should render with undefined data', () => { + render( + , + ) + + // Modal should still render but without content + expect(screen.queryByText('Test App')).not.toBeInTheDocument() + }) + + it('should render with image icon type', () => { + const siteInfoWithImage: SiteInfo = { + ...baseSiteInfo, + icon_type: 'image', + icon_url: 'https://example.com/icon.png', + } + + render( + , + ) + + expect(screen.getByText(siteInfoWithImage.title!)).toBeInTheDocument() + }) + }) + + describe('close functionality', () => { + it('should call onClose when close button is clicked', () => { + render( + , + ) + + // Find the close icon (RiCloseLine) which has text-text-tertiary class + const closeIcon = document.querySelector('[class*="text-text-tertiary"]') + expect(closeIcon).toBeInTheDocument() + if (closeIcon) { + fireEvent.click(closeIcon) + expect(mockOnClose).toHaveBeenCalled() + } + }) + }) + + describe('both copyright and disclaimer', () => { + it('should render both when both are provided', () => { + const siteInfoWithBoth: SiteInfo = { + ...baseSiteInfo, + copyright: 'My Company', + custom_disclaimer: 'Disclaimer text here', + } + + render( + , + ) + + expect(screen.getByText(/My Company/)).toBeInTheDocument() + expect(screen.getByText('Disclaimer text here')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/share/text-generation/menu-dropdown.spec.tsx b/web/app/components/share/text-generation/menu-dropdown.spec.tsx new file mode 100644 index 0000000000..b54a2df632 --- /dev/null +++ b/web/app/components/share/text-generation/menu-dropdown.spec.tsx @@ -0,0 +1,261 @@ +import type { SiteInfo } from '@/models/share' +import { act, cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import MenuDropdown from './menu-dropdown' + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock next/navigation +const mockReplace = vi.fn() +const mockPathname = '/test-path' +vi.mock('next/navigation', () => ({ + useRouter: () => ({ + replace: mockReplace, + }), + usePathname: () => mockPathname, +})) + +// Mock web-app-context +const mockShareCode = 'test-share-code' +vi.mock('@/context/web-app-context', () => ({ + useWebAppStore: (selector: (state: Record) => unknown) => { + const state = { + webAppAccessMode: 'code', + shareCode: mockShareCode, + } + return selector(state) + }, +})) + +// Mock webapp-auth service +const mockWebAppLogout = vi.fn().mockResolvedValue(undefined) +vi.mock('@/service/webapp-auth', () => ({ + webAppLogout: (...args: unknown[]) => mockWebAppLogout(...args), +})) + +afterEach(() => { + cleanup() +}) + +describe('MenuDropdown', () => { + const baseSiteInfo: SiteInfo = { + title: 'Test App', + icon: '๐Ÿš€', + icon_type: 'emoji', + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the trigger button', () => { + render() + + // The trigger button contains a settings icon (RiEqualizer2Line) + const triggerButton = screen.getByRole('button') + expect(triggerButton).toBeInTheDocument() + }) + + it('should not show dropdown content initially', () => { + render() + + // Dropdown content should not be visible initially + expect(screen.queryByText('theme.theme')).not.toBeInTheDocument() + }) + + it('should show dropdown content when clicked', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('theme.theme')).toBeInTheDocument() + }) + }) + + it('should show About option in dropdown', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('userProfile.about')).toBeInTheDocument() + }) + }) + }) + + describe('privacy policy link', () => { + it('should show privacy policy link when provided', async () => { + const siteInfoWithPrivacy: SiteInfo = { + ...baseSiteInfo, + privacy_policy: 'https://example.com/privacy', + } + + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('chat.privacyPolicyMiddle')).toBeInTheDocument() + }) + }) + + it('should not show privacy policy link when not provided', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.queryByText('chat.privacyPolicyMiddle')).not.toBeInTheDocument() + }) + }) + + it('should have correct href for privacy policy link', async () => { + const privacyUrl = 'https://example.com/privacy' + const siteInfoWithPrivacy: SiteInfo = { + ...baseSiteInfo, + privacy_policy: privacyUrl, + } + + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + const link = screen.getByText('chat.privacyPolicyMiddle').closest('a') + expect(link).toHaveAttribute('href', privacyUrl) + expect(link).toHaveAttribute('target', '_blank') + }) + }) + }) + + describe('logout functionality', () => { + it('should show logout option when hideLogout is false', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('userProfile.logout')).toBeInTheDocument() + }) + }) + + it('should hide logout option when hideLogout is true', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.queryByText('userProfile.logout')).not.toBeInTheDocument() + }) + }) + + it('should call webAppLogout and redirect when logout is clicked', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('userProfile.logout')).toBeInTheDocument() + }) + + const logoutButton = screen.getByText('userProfile.logout') + await act(async () => { + fireEvent.click(logoutButton) + }) + + await waitFor(() => { + expect(mockWebAppLogout).toHaveBeenCalledWith(mockShareCode) + expect(mockReplace).toHaveBeenCalledWith(`/webapp-signin?redirect_url=${mockPathname}`) + }) + }) + }) + + describe('about modal', () => { + it('should show InfoModal when About is clicked', async () => { + render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('userProfile.about')).toBeInTheDocument() + }) + + const aboutButton = screen.getByText('userProfile.about') + fireEvent.click(aboutButton) + + await waitFor(() => { + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + }) + }) + + describe('forceClose prop', () => { + it('should close dropdown when forceClose changes to true', async () => { + const { rerender } = render() + + const triggerButton = screen.getByRole('button') + fireEvent.click(triggerButton) + + await waitFor(() => { + expect(screen.getByText('theme.theme')).toBeInTheDocument() + }) + + rerender() + + await waitFor(() => { + expect(screen.queryByText('theme.theme')).not.toBeInTheDocument() + }) + }) + }) + + describe('placement prop', () => { + it('should accept custom placement', () => { + render() + + const triggerButton = screen.getByRole('button') + expect(triggerButton).toBeInTheDocument() + }) + }) + + describe('toggle behavior', () => { + it('should close dropdown when clicking trigger again', async () => { + render() + + const triggerButton = screen.getByRole('button') + + // Open + fireEvent.click(triggerButton) + await waitFor(() => { + expect(screen.getByText('theme.theme')).toBeInTheDocument() + }) + + // Close + fireEvent.click(triggerButton) + await waitFor(() => { + expect(screen.queryByText('theme.theme')).not.toBeInTheDocument() + }) + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((MenuDropdown as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) +}) diff --git a/web/app/components/share/text-generation/result/content.spec.tsx b/web/app/components/share/text-generation/result/content.spec.tsx new file mode 100644 index 0000000000..242ae7aa5f --- /dev/null +++ b/web/app/components/share/text-generation/result/content.spec.tsx @@ -0,0 +1,133 @@ +import type { FeedbackType } from '@/app/components/base/chat/chat/type' +import { cleanup, render, screen } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import Result from './content' + +// Only mock react-i18next for translations +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock copy-to-clipboard for the Header component +vi.mock('copy-to-clipboard', () => ({ + default: vi.fn(() => true), +})) + +// Mock the format function from service/base +vi.mock('@/service/base', () => ({ + format: (content: string) => content.replace(/\n/g, '
    '), +})) + +afterEach(() => { + cleanup() +}) + +describe('Result (content)', () => { + const mockOnFeedback = vi.fn() + + const defaultProps = { + content: 'Test content here', + showFeedback: true, + feedback: { rating: null } as FeedbackType, + onFeedback: mockOnFeedback, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the Header component', () => { + render() + + // Header renders the result title + expect(screen.getByText('generation.resultTitle')).toBeInTheDocument() + }) + + it('should render content', () => { + render() + + expect(screen.getByText('Test content here')).toBeInTheDocument() + }) + + it('should render formatted content with line breaks', () => { + render( + , + ) + + // The format function converts \n to
    + const contentDiv = document.querySelector('[class*="overflow-scroll"]') + expect(contentDiv?.innerHTML).toContain('Line 1
    Line 2') + }) + + it('should have max height style', () => { + render() + + const contentDiv = document.querySelector('[class*="overflow-scroll"]') + expect(contentDiv).toHaveStyle({ maxHeight: '70vh' }) + }) + + it('should render with empty content', () => { + render( + , + ) + + expect(screen.getByText('generation.resultTitle')).toBeInTheDocument() + }) + + it('should render with HTML content safely', () => { + render( + , + ) + + // Content is rendered via dangerouslySetInnerHTML + const contentDiv = document.querySelector('[class*="overflow-scroll"]') + expect(contentDiv).toBeInTheDocument() + }) + }) + + describe('feedback props', () => { + it('should pass showFeedback to Header', () => { + render( + , + ) + + // Feedback buttons should not be visible + const feedbackArea = document.querySelector('[class*="space-x-1 rounded-lg border"]') + expect(feedbackArea).not.toBeInTheDocument() + }) + + it('should pass feedback to Header', () => { + render( + , + ) + + // Like button should be highlighted + const likeButton = document.querySelector('[class*="primary"]') + expect(likeButton).toBeInTheDocument() + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((Result as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) +}) diff --git a/web/app/components/share/text-generation/result/header.spec.tsx b/web/app/components/share/text-generation/result/header.spec.tsx new file mode 100644 index 0000000000..b2ef0fadc4 --- /dev/null +++ b/web/app/components/share/text-generation/result/header.spec.tsx @@ -0,0 +1,176 @@ +import type { FeedbackType } from '@/app/components/base/chat/chat/type' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import Header from './header' + +// Only mock react-i18next for translations +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +// Mock copy-to-clipboard +const mockCopy = vi.fn((_text: string) => true) +vi.mock('copy-to-clipboard', () => ({ + default: (text: string) => mockCopy(text), +})) + +afterEach(() => { + cleanup() +}) + +describe('Header', () => { + const mockOnFeedback = vi.fn() + + const defaultProps = { + result: 'Test result content', + showFeedback: true, + feedback: { rating: null } as FeedbackType, + onFeedback: mockOnFeedback, + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('rendering', () => { + it('should render the result title', () => { + render(
    ) + + expect(screen.getByText('generation.resultTitle')).toBeInTheDocument() + }) + + it('should render the copy button', () => { + render(
    ) + + expect(screen.getByText('generation.copy')).toBeInTheDocument() + }) + }) + + describe('copy functionality', () => { + it('should copy result when copy button is clicked', () => { + render(
    ) + + const copyButton = screen.getByText('generation.copy').closest('button') + fireEvent.click(copyButton!) + + expect(mockCopy).toHaveBeenCalledWith('Test result content') + }) + }) + + describe('feedback buttons when showFeedback is true', () => { + it('should show feedback buttons when no rating is given', () => { + render(
    ) + + // Should show both thumbs up and down buttons + const buttons = document.querySelectorAll('[class*="cursor-pointer"]') + expect(buttons.length).toBeGreaterThan(0) + }) + + it('should show like button highlighted when rating is like', () => { + render( +
    , + ) + + // Should show the undo button for like + const likeButton = document.querySelector('[class*="primary"]') + expect(likeButton).toBeInTheDocument() + }) + + it('should show dislike button highlighted when rating is dislike', () => { + render( +
    , + ) + + // Should show the undo button for dislike + const dislikeButton = document.querySelector('[class*="red"]') + expect(dislikeButton).toBeInTheDocument() + }) + + it('should call onFeedback with like when thumbs up is clicked', () => { + render(
    ) + + // Find the thumbs up button (first one in the feedback area) + const thumbButtons = document.querySelectorAll('[class*="cursor-pointer"]') + const thumbsUp = Array.from(thumbButtons).find(btn => + btn.className.includes('rounded-md') && !btn.className.includes('primary'), + ) + + if (thumbsUp) { + fireEvent.click(thumbsUp) + expect(mockOnFeedback).toHaveBeenCalledWith({ rating: 'like' }) + } + }) + + it('should call onFeedback with dislike when thumbs down is clicked', () => { + render(
    ) + + // Find the thumbs down button + const thumbButtons = document.querySelectorAll('[class*="cursor-pointer"]') + const thumbsDown = Array.from(thumbButtons).pop() + + if (thumbsDown) { + fireEvent.click(thumbsDown) + expect(mockOnFeedback).toHaveBeenCalledWith({ rating: 'dislike' }) + } + }) + + it('should call onFeedback with null when undo like is clicked', () => { + render( +
    , + ) + + // When liked, clicking the like button again should undo it (has bg-primary-100 class) + const likeButton = document.querySelector('[class*="bg-primary-100"]') + expect(likeButton).toBeInTheDocument() + fireEvent.click(likeButton!) + expect(mockOnFeedback).toHaveBeenCalledWith({ rating: null }) + }) + + it('should call onFeedback with null when undo dislike is clicked', () => { + render( +
    , + ) + + // When disliked, clicking the dislike button again should undo it (has bg-red-100 class) + const dislikeButton = document.querySelector('[class*="bg-red-100"]') + expect(dislikeButton).toBeInTheDocument() + fireEvent.click(dislikeButton!) + expect(mockOnFeedback).toHaveBeenCalledWith({ rating: null }) + }) + }) + + describe('feedback buttons when showFeedback is false', () => { + it('should not show feedback buttons', () => { + render( +
    , + ) + + // Should not show feedback area buttons (only copy button) + const feedbackArea = document.querySelector('[class*="space-x-1 rounded-lg border"]') + expect(feedbackArea).not.toBeInTheDocument() + }) + }) + + describe('memoization', () => { + it('should be wrapped with React.memo', () => { + expect((Header as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo')) + }) + }) +}) diff --git a/web/app/components/share/text-generation/run-once/index.spec.tsx b/web/app/components/share/text-generation/run-once/index.spec.tsx index ea5ce3c902..af3d723d20 100644 --- a/web/app/components/share/text-generation/run-once/index.spec.tsx +++ b/web/app/components/share/text-generation/run-once/index.spec.tsx @@ -1,6 +1,7 @@ +import type { InputValueTypes } from '../types' import type { PromptConfig, PromptVariable } from '@/models/debug' import type { SiteInfo } from '@/models/share' -import type { VisionSettings } from '@/types/app' +import type { VisionFile, VisionSettings } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' import { useEffect, useRef, useState } from 'react' @@ -27,7 +28,7 @@ vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', ( })) vi.mock('@/app/components/base/image-uploader/text-generation-image-uploader', () => { - function TextGenerationImageUploaderMock({ onFilesChange }: { onFilesChange: (files: any[]) => void }) { + function TextGenerationImageUploaderMock({ onFilesChange }: { onFilesChange: (files: VisionFile[]) => void }) { useEffect(() => { onFilesChange([]) }, [onFilesChange]) @@ -38,6 +39,20 @@ vi.mock('@/app/components/base/image-uploader/text-generation-image-uploader', ( } }) +// Mock FileUploaderInAttachmentWrapper as it requires context providers not available in tests +vi.mock('@/app/components/base/file-uploader', () => ({ + FileUploaderInAttachmentWrapper: ({ value, onChange }: { value: object[], onChange: (files: object[]) => void }) => ( +
    + + + {value?.length || 0} + {' '} + files + +
    + ), +})) + const createPromptVariable = (overrides: Partial): PromptVariable => ({ key: 'input', name: 'Input', @@ -95,11 +110,11 @@ const setup = (overrides: { const onInputsChange = vi.fn() const onSend = vi.fn() const onVisionFilesChange = vi.fn() - let inputsRefCapture: React.MutableRefObject> | null = null + let inputsRefCapture: React.MutableRefObject> | null = null const Wrapper = () => { - const [inputs, setInputs] = useState>({}) - const inputsRef = useRef>({}) + const [inputs, setInputs] = useState>({}) + const inputsRef = useRef>({}) inputsRefCapture = inputsRef return ( { expect(stopButton).toBeDisabled() }) + describe('select input type', () => { + it('should render select input and handle selection', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'selectInput', + name: 'Select Input', + type: 'select', + options: ['Option A', 'Option B', 'Option C'], + default: 'Option A', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalledWith({ + selectInput: 'Option A', + }) + }) + // The Select component should be rendered + expect(screen.getByText('Select Input')).toBeInTheDocument() + }) + }) + + describe('file input types', () => { + it('should render file uploader for single file input', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'fileInput', + name: 'File Input', + type: 'file', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalledWith({ + fileInput: undefined, + }) + }) + expect(screen.getByText('File Input')).toBeInTheDocument() + }) + + it('should render file uploader for file-list input', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'fileListInput', + name: 'File List Input', + type: 'file-list', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalledWith({ + fileListInput: [], + }) + }) + expect(screen.getByText('File List Input')).toBeInTheDocument() + }) + }) + + describe('json_object input type', () => { + it('should render code editor for json_object input', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'jsonInput', + name: 'JSON Input', + type: 'json_object' as PromptVariable['type'], + json_schema: '{"type": "object"}', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalledWith({ + jsonInput: undefined, + }) + }) + expect(screen.getByText('JSON Input')).toBeInTheDocument() + expect(screen.getByTestId('code-editor-mock')).toBeInTheDocument() + }) + + it('should update json_object input when code editor changes', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'jsonInput', + name: 'JSON Input', + type: 'json_object' as PromptVariable['type'], + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + onInputsChange.mockClear() + + const codeEditor = screen.getByTestId('code-editor-mock') + fireEvent.change(codeEditor, { target: { value: '{"key": "value"}' } }) + + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalledWith({ + jsonInput: '{"key": "value"}', + }) + }) + }) + }) + + describe('hidden and optional fields', () => { + it('should not render hidden variables', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'hiddenInput', + name: 'Hidden Input', + type: 'string', + hide: true, + }), + createPromptVariable({ + key: 'visibleInput', + name: 'Visible Input', + type: 'string', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + expect(screen.queryByText('Hidden Input')).not.toBeInTheDocument() + expect(screen.getByText('Visible Input')).toBeInTheDocument() + }) + + it('should show optional label for non-required fields', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'optionalInput', + name: 'Optional Input', + type: 'string', + required: false, + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + expect(screen.getByText('workflow.panel.optional')).toBeInTheDocument() + }) + }) + + describe('vision uploader', () => { + it('should not render vision uploader when disabled', async () => { + const { onInputsChange } = setup({ visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + expect(screen.queryByText('common.imageUploader.imageUpload')).not.toBeInTheDocument() + }) + }) + + describe('clear with different input types', () => { + it('should clear select input to undefined', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'selectInput', + name: 'Select Input', + type: 'select', + options: ['Option A', 'Option B'], + default: 'Option A', + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + onInputsChange.mockClear() + + fireEvent.click(screen.getByRole('button', { name: 'common.operation.clear' })) + + expect(onInputsChange).toHaveBeenCalledWith({ + selectInput: undefined, + }) + }) + }) + describe('maxLength behavior', () => { it('should not have maxLength attribute when max_length is not set', async () => { const promptConfig: PromptConfig = { diff --git a/web/app/components/share/utils.spec.ts b/web/app/components/share/utils.spec.ts new file mode 100644 index 0000000000..ee2aab58eb --- /dev/null +++ b/web/app/components/share/utils.spec.ts @@ -0,0 +1,71 @@ +import { describe, expect, it } from 'vitest' +import { getInitialTokenV2, isTokenV1 } from './utils' + +describe('utils', () => { + describe('isTokenV1', () => { + it('should return true when token has no version property', () => { + const token = { someKey: 'value' } + expect(isTokenV1(token)).toBe(true) + }) + + it('should return true when token.version is undefined', () => { + const token = { version: undefined } + expect(isTokenV1(token)).toBe(true) + }) + + it('should return true when token.version is null', () => { + const token = { version: null } + expect(isTokenV1(token)).toBe(true) + }) + + it('should return true when token.version is 0', () => { + const token = { version: 0 } + expect(isTokenV1(token)).toBe(true) + }) + + it('should return true when token.version is empty string', () => { + const token = { version: '' } + expect(isTokenV1(token)).toBe(true) + }) + + it('should return false when token has version 1', () => { + const token = { version: 1 } + expect(isTokenV1(token)).toBe(false) + }) + + it('should return false when token has version 2', () => { + const token = { version: 2 } + expect(isTokenV1(token)).toBe(false) + }) + + it('should return false when token has string version', () => { + const token = { version: '2' } + expect(isTokenV1(token)).toBe(false) + }) + + it('should handle empty object', () => { + const token = {} + expect(isTokenV1(token)).toBe(true) + }) + }) + + describe('getInitialTokenV2', () => { + it('should return object with version 2', () => { + const token = getInitialTokenV2() + expect(token.version).toBe(2) + }) + + it('should return a new object each time', () => { + const token1 = getInitialTokenV2() + const token2 = getInitialTokenV2() + expect(token1).not.toBe(token2) + }) + + it('should return an object that can be modified without affecting future calls', () => { + const token1 = getInitialTokenV2() + token1.customField = 'test' + const token2 = getInitialTokenV2() + expect(token2.customField).toBeUndefined() + }) + }) +}) diff --git a/web/app/components/workflow-app/hooks/use-DSL.ts b/web/app/components/workflow-app/hooks/use-DSL.ts index 6c01509bc5..939e43b554 100644 --- a/web/app/components/workflow-app/hooks/use-DSL.ts +++ b/web/app/components/workflow-app/hooks/use-DSL.ts @@ -11,6 +11,7 @@ import { import { useEventEmitterContextContext } from '@/context/event-emitter' import { exportAppConfig } from '@/service/apps' import { fetchWorkflowDraft } from '@/service/workflow' +import { downloadBlob } from '@/utils/download' import { useNodesSyncDraft } from './use-nodes-sync-draft' export const useDSL = () => { @@ -37,13 +38,8 @@ export const useDSL = () => { include, workflowID: workflowId, }) - const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) - const url = URL.createObjectURL(file) - a.href = url - a.download = `${appDetail.name}.yml` - a.click() - URL.revokeObjectURL(url) + downloadBlob({ data: file, fileName: `${appDetail.name}.yml` }) } catch { notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) diff --git a/web/app/components/workflow/block-selector/market-place-plugin/action.tsx b/web/app/components/workflow/block-selector/market-place-plugin/action.tsx index b8300d6f2b..abdbae1b4c 100644 --- a/web/app/components/workflow/block-selector/market-place-plugin/action.tsx +++ b/web/app/components/workflow/block-selector/market-place-plugin/action.tsx @@ -15,7 +15,7 @@ import { } from '@/app/components/base/portal-to-follow-elem' import { useDownloadPlugin } from '@/service/use-plugins' import { cn } from '@/utils/classnames' -import { downloadFile } from '@/utils/format' +import { downloadBlob } from '@/utils/download' import { getMarketplaceUrl } from '@/utils/var' type Props = { @@ -67,7 +67,7 @@ const OperationDropdown: FC = ({ if (!needDownload || !blob) return const fileName = `${author}-${name}_${version}.zip` - downloadFile({ data: blob, fileName }) + downloadBlob({ data: blob, fileName }) setNeedDownload(false) queryClient.removeQueries({ queryKey: ['plugins', 'downloadPlugin', downloadInfo], diff --git a/web/app/components/workflow/nodes/http/components/curl-panel.tsx b/web/app/components/workflow/nodes/http/components/curl-panel.tsx index aa67a2a0ae..6c809c310f 100644 --- a/web/app/components/workflow/nodes/http/components/curl-panel.tsx +++ b/web/app/components/workflow/nodes/http/components/curl-panel.tsx @@ -41,7 +41,7 @@ const parseCurl = (curlCommand: string): { node: HttpNodeType | null, error: str case '--request': if (i + 1 >= args.length) return { node: null, error: 'Missing HTTP method after -X or --request.' } - node.method = (args[++i].replace(/^['"]|['"]$/g, '') as Method) || Method.get + node.method = (args[++i].replace(/^['"]|['"]$/g, '').toLowerCase() as Method) || Method.get hasData = true break case '-H': diff --git a/web/app/components/workflow/operator/more-actions.tsx b/web/app/components/workflow/operator/more-actions.tsx index e9fc1ea87d..7e6617e84b 100644 --- a/web/app/components/workflow/operator/more-actions.tsx +++ b/web/app/components/workflow/operator/more-actions.tsx @@ -19,6 +19,7 @@ import { } from '@/app/components/base/portal-to-follow-elem' import { useStore } from '@/app/components/workflow/store' import { cn } from '@/utils/classnames' +import { downloadUrl } from '@/utils/download' import { useNodesReadOnly } from '../hooks' import TipPopup from './tip-popup' @@ -146,26 +147,14 @@ const MoreActions: FC = () => { } } + const fileName = `${filename}.${type}` + if (currentWorkflow) { setPreviewUrl(dataUrl) - setPreviewTitle(`${filename}.${type}`) + setPreviewTitle(fileName) + } - const link = document.createElement('a') - link.href = dataUrl - link.download = `${filename}.${type}` - document.body.appendChild(link) - link.click() - document.body.removeChild(link) - } - else { - // For current view, just download - const link = document.createElement('a') - link.href = dataUrl - link.download = `${filename}.${type}` - document.body.appendChild(link) - link.click() - document.body.removeChild(link) - } + downloadUrl({ url: dataUrl, fileName }) } catch (error) { console.error('Export image failed:', error) diff --git a/web/app/styles/monaco-sticky-fix.css b/web/app/styles/monaco-sticky-fix.css index 66bb5921ce..ac928cf246 100644 --- a/web/app/styles/monaco-sticky-fix.css +++ b/web/app/styles/monaco-sticky-fix.css @@ -9,8 +9,7 @@ html[data-theme="dark"] .monaco-editor .sticky-line-content:hover { background-color: var(--color-components-sticky-header-bg-hover) !important; } -/* Fallback: any app sticky header using input-bg variables should use the sticky header bg when sticky */ -html[data-theme="dark"] .sticky, html[data-theme="dark"] .is-sticky { +/* Monaco editor specific sticky scroll styles in dark mode */ +html[data-theme="dark"] .monaco-editor .sticky-line-root { background-color: var(--color-components-sticky-header-bg) !important; - border-bottom: 1px solid var(--color-components-sticky-header-border) !important; } \ No newline at end of file diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index 1d7465fffe..965ac492fb 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -1004,7 +1004,7 @@ "count": 1 }, "ts/no-explicit-any": { - "count": 3 + "count": 2 } }, "app/components/base/file-uploader/utils.ts": { @@ -1686,7 +1686,7 @@ "count": 1 }, "ts/no-explicit-any": { - "count": 5 + "count": 4 } }, "app/components/datasets/create/website/watercrawl/options.tsx": { @@ -2609,11 +2609,6 @@ "count": 2 } }, - "app/components/share/text-generation/run-once/index.spec.tsx": { - "ts/no-explicit-any": { - "count": 4 - } - }, "app/components/share/text-generation/run-once/index.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -4434,11 +4429,6 @@ "count": 1 } }, - "utils/format.spec.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, "utils/get-icon.spec.ts": { "ts/no-explicit-any": { "count": 2 diff --git a/web/next.config.ts b/web/next.config.ts index fc4dee3289..05f4158ac8 100644 --- a/web/next.config.ts +++ b/web/next.config.ts @@ -67,6 +67,9 @@ const nextConfig: NextConfig = { compiler: { removeConsole: isDev ? false : { exclude: ['warn', 'error'] }, }, + experimental: { + turbopackFileSystemCacheForDev: false, + }, } export default withBundleAnalyzer(withMDX(nextConfig)) diff --git a/web/utils/download.spec.ts b/web/utils/download.spec.ts new file mode 100644 index 0000000000..ff41ddfff7 --- /dev/null +++ b/web/utils/download.spec.ts @@ -0,0 +1,75 @@ +import { downloadBlob, downloadUrl } from './download' + +describe('downloadUrl', () => { + let mockAnchor: HTMLAnchorElement + + beforeEach(() => { + mockAnchor = { + href: '', + download: '', + rel: '', + target: '', + style: { display: '' }, + click: vi.fn(), + remove: vi.fn(), + } as unknown as HTMLAnchorElement + + vi.spyOn(document, 'createElement').mockReturnValue(mockAnchor) + vi.spyOn(document.body, 'appendChild').mockImplementation((node: Node) => node) + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + it('should create a link and trigger a download correctly', () => { + downloadUrl({ url: 'https://example.com/file.txt', fileName: 'file.txt', target: '_blank' }) + + expect(mockAnchor.href).toBe('https://example.com/file.txt') + expect(mockAnchor.download).toBe('file.txt') + expect(mockAnchor.rel).toBe('noopener noreferrer') + expect(mockAnchor.target).toBe('_blank') + expect(mockAnchor.style.display).toBe('none') + expect(mockAnchor.click).toHaveBeenCalled() + expect(mockAnchor.remove).toHaveBeenCalled() + }) + + it('should skip when url is empty', () => { + downloadUrl({ url: '' }) + expect(document.createElement).not.toHaveBeenCalled() + }) +}) + +describe('downloadBlob', () => { + it('should create a blob url, trigger download, and revoke url', () => { + const blob = new Blob(['test'], { type: 'text/plain' }) + const mockUrl = 'blob:mock-url' + const createObjectURLMock = vi.spyOn(window.URL, 'createObjectURL').mockReturnValue(mockUrl) + const revokeObjectURLMock = vi.spyOn(window.URL, 'revokeObjectURL').mockImplementation(() => {}) + + const mockAnchor = { + href: '', + download: '', + rel: '', + target: '', + style: { display: '' }, + click: vi.fn(), + remove: vi.fn(), + } as unknown as HTMLAnchorElement + + vi.spyOn(document, 'createElement').mockReturnValue(mockAnchor) + vi.spyOn(document.body, 'appendChild').mockImplementation((node: Node) => node) + + downloadBlob({ data: blob, fileName: 'file.txt' }) + + expect(createObjectURLMock).toHaveBeenCalledWith(blob) + expect(mockAnchor.href).toBe(mockUrl) + expect(mockAnchor.download).toBe('file.txt') + expect(mockAnchor.rel).toBe('noopener noreferrer') + expect(mockAnchor.click).toHaveBeenCalled() + expect(mockAnchor.remove).toHaveBeenCalled() + expect(revokeObjectURLMock).toHaveBeenCalledWith(mockUrl) + + vi.restoreAllMocks() + }) +}) diff --git a/web/utils/format.spec.ts b/web/utils/format.spec.ts index 3a1709dbdc..2796854e34 100644 --- a/web/utils/format.spec.ts +++ b/web/utils/format.spec.ts @@ -1,4 +1,4 @@ -import { downloadFile, formatFileSize, formatNumber, formatNumberAbbreviated, formatTime } from './format' +import { formatFileSize, formatNumber, formatNumberAbbreviated, formatTime } from './format' describe('formatNumber', () => { it('should correctly format integers', () => { @@ -82,49 +82,6 @@ describe('formatTime', () => { expect(formatTime(7200)).toBe('2.00 h') }) }) -describe('downloadFile', () => { - it('should create a link and trigger a download correctly', () => { - // Mock data - const blob = new Blob(['test content'], { type: 'text/plain' }) - const fileName = 'test-file.txt' - const mockUrl = 'blob:mockUrl' - - // Mock URL.createObjectURL - const createObjectURLMock = vi.fn().mockReturnValue(mockUrl) - const revokeObjectURLMock = vi.fn() - Object.defineProperty(window.URL, 'createObjectURL', { value: createObjectURLMock }) - Object.defineProperty(window.URL, 'revokeObjectURL', { value: revokeObjectURLMock }) - - // Mock createElement and appendChild - const mockLink = { - href: '', - download: '', - click: vi.fn(), - remove: vi.fn(), - } - const createElementMock = vi.spyOn(document, 'createElement').mockReturnValue(mockLink as any) - const appendChildMock = vi.spyOn(document.body, 'appendChild').mockImplementation((node: Node) => { - return node - }) - - // Call the function - downloadFile({ data: blob, fileName }) - - // Assertions - expect(createObjectURLMock).toHaveBeenCalledWith(blob) - expect(createElementMock).toHaveBeenCalledWith('a') - expect(mockLink.href).toBe(mockUrl) - expect(mockLink.download).toBe(fileName) - expect(appendChildMock).toHaveBeenCalledWith(mockLink) - expect(mockLink.click).toHaveBeenCalled() - expect(mockLink.remove).toHaveBeenCalled() - expect(revokeObjectURLMock).toHaveBeenCalledWith(mockUrl) - - // Clean up mocks - vi.restoreAllMocks() - }) -}) - describe('formatNumberAbbreviated', () => { it('should return number as string when less than 1000', () => { expect(formatNumberAbbreviated(0)).toBe('0') diff --git a/web/utils/format.ts b/web/utils/format.ts index ce813d3999..d6968e0ef1 100644 --- a/web/utils/format.ts +++ b/web/utils/format.ts @@ -100,17 +100,6 @@ export const formatTime = (seconds: number) => { return `${seconds.toFixed(2)} ${units[index]}` } -export const downloadFile = ({ data, fileName }: { data: Blob, fileName: string }) => { - const url = window.URL.createObjectURL(data) - const a = document.createElement('a') - a.href = url - a.download = fileName - document.body.appendChild(a) - a.click() - a.remove() - window.URL.revokeObjectURL(url) -} - /** * Formats a number into a readable string using "k", "M", or "B" suffix. * @example